[
  {
    "path": ".gitignore",
    "content": "**/*.pt\n**/checkpoints\n**/wget-log\n**/_build/\n**/*.ckpt\n**/outputs\n**/*.tar.gz\n**/playground\n**/wandb\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\ndataset/*\ntensorflow/my_graph/*\n.idea/\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nenv/\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\ntmp/\n*.egg-info/\n.installed.cfg\n*.egg\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*,cover\n.hypothesis/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# IPython Notebook\n.ipynb_checkpoints\n\n# pyenv\n.python-version\n\n# celery beat schedule file\ncelerybeat-schedule\n\n# dotenv\n.env\n\n# virtualenv\nvenv/\n.venv/\nENV/\n\n# Spyder project settings\n.spyderproject\n\n# Rope project settings\n.ropeproject\n\n# vscode\n.vscode\n\n# Mac\n.DS_Store\n\n# vim\n*.swp\n\n# ckpt\n*.lock\n\n# data\n*.parquet\n\n\n# local logs\nlogs\nlog\noutputs\n.history\n\n*tensorboard\ntensorboard/\n\n# version file\nsiirl/_version.py\n"
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "\n# Default list of files to exclude from checks.\n# Add any other paths that should be ignored by all hooks.\nexclude: |\n  (?x)^(\n      docs/.*|\n      build/.*\n  )$\n\nrepos:\n-   repo: https://github.com/pre-commit/pre-commit-hooks\n    rev: v5.0.0\n    hooks:\n    -   id: trailing-whitespace\n    -   id: end-of-file-fixer\n    -   id: check-yaml\n    -   id: check-added-large-files\n        args: [--maxkb=500]\n    -   id: check-case-conflict\n    -   id: check-executables-have-shebangs\n    -   id: check-merge-conflict\n    -   id: check-symlinks\n    -   id: detect-private-key\n\n-   repo: https://github.com/astral-sh/ruff-pre-commit\n    rev: v0.12.6\n    hooks:\n    -   id: ruff\n        args: [\"--fix\", \"--show-fixes\", \"--output-format=full\"]\n    -   id: ruff-format\n\n-   repo: https://github.com/codespell-project/codespell\n    rev: v2.4.0\n    hooks:\n    -   id: codespell\n        args:\n          - --skip=\"*.json,*.txt\"\n          - --ignore-words-list=nd,repostory\n"
  },
  {
    "path": ".readthedocs.yaml",
    "content": "# Read the Docs configuration file\n# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details\n\n# Required\nversion: 2\n\n# Set the OS, Python version, and other tools you might need\nbuild:\n  os: ubuntu-22.04\n  tools:\n    python: \"3.11\"\n\n# Build documentation in the \"docs/\" directory with Sphinx\nsphinx:\n   configuration: docs/conf.py\n\n# Optionally, but recommended,\n# declare the Python requirements required to build your documentation\n# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html\npython:\n  install:\n    - requirements: docs/requirements-docs.txt\n    - method: pip\n      path: .\n\n        "
  },
  {
    "path": "CONTRIBUTING.md",
    "content": "# Contributing to siiRL\n\nThank you for considering contributing to siiRL!\n\nWe welcome contributions in various forms, including but not limited to:\n- Reporting a bug\n- Submitting a fix\n- Discussing the current state of the code\n- Proposing new features\n- Becoming a maintainer\n- Review pull requests\n- Add/Improve documentation\n- ...\n\n## Getting Started\n\nTo get started, please fork the latest branch.\n\n### Reporting Bugs\n\nIf you find a bug, please open an issue on our GitHub repository. When you are creating a bug report, please include as many details as possible. Fill out the required template, detailed information helps us resolve issues faster.\n\n### Suggesting Enhancements\n\nIf you have an idea for a new feature or an enhancement to an existing one, please open an issue on our GitHub repository. This allows for a discussion with the community and the project maintainers.\n\n### Pull Requests\n\nWe actively welcome your pull requests.\n\n1. Fork the repo and create your branch from `main`.\n2. If you've added code that should be tested, add tests.\n3. If you've changed APIs, update the documentation.\n4. Ensure the test suite passes.\n5. Make sure your code lints.\n6. Issue that pull request!\n\n## Styleguides\n\n### Git Commit Messages\n\n- Use the present tense (\"Add feature\" not \"Added feature\").\n- Use the imperative mood (\"Move A to...\" not \"Moves A to...\").\n- Limit the first line to 72 characters or less.\n- Reference issues and pull requests liberally after the first line.\n\n<!-- ### Code Style\n\nWe use XXX for code formatting and XXX for linting. Before submitting your pull request, please make sure your code is formatted and linted.\n\n```bash\n# Auto-format your code\npip install black\nblack .\n\n# Lint your code\npip install ruff\nruff .\n``` -->\n\n## Any questions?\n\nDon't hesitate to contact us if you have any questions. You can reach out to us by opening an issue on GitHub.\n\nWe are excited to see your contributions! "
  },
  {
    "path": "LICENSE",
    "content": "\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. 
For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. 
Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. (Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README-zh.md",
    "content": "<div align=\"center\">\n  <img src=\"asset/sii.png\" width=\"100%\"/>\n  <br>\n</div>\n<br>\n\n<h1 align=\"center\">\nsiiRL: Shanghai Innovation Institute RL Framework for Advanced LLMs and Multi-Agent Systems\n</h1>\n\n<p align=\"center\">\n| <a href=\"https://arxiv.org/abs/2507.13833\"><b>📄 论文</b></a> | \n| <a href=\"https://siirl.readthedocs.io/en/latest/index.html\"><b>📚 文档</b></a> |\n| <a href=\"asset/siiRL-feishu-group.png\">\n    <img src=\"asset/logo-feishu.png\" alt=\"Feishu Group QR Code\" height=\"15\" /> \n    <b> 飞书群</b>\n  </a> \n| <a href=\"asset/siiRL-wechat-group.png\">\n    <img src=\"asset/logo-wechat.png\" alt=\"Wechat Group QR Code\" height=\"15\" /> \n    <b> 微信群</b>\n  </a> \n| <a href=\"README.md\"><b> English</b></a> |\n</p>\n\n**siiRL** 是一个新型的、**完全分布式的强化学习 (RL) 框架**，旨在突破大语言模型 (LLM) 后训练中的扩展性瓶颈，并支持未来的多智能体研究，由**上海创智学院**的研究人员开发。\n\n通过移除主流框架中的中心化数据流控制器，siiRL 实现了**近线性的扩展能力**、**显著的吞吐量提升**，通过DAG模块化的设计获得了**极大的的灵活性**，为基于强化学习的 LLM 开发带来了全新的可能性。\n\n---\n\n## 🚀 亮点\n\n+ **近线性扩展能力**: 多控制器模式通过将控制逻辑和数据管理分布到所有工作节点，消除了中心化瓶颈，从而实现了在数千张 GPU 上的近线性扩展。\n\n+ **业界领先的吞吐量 (SOTA)**: 完全分布式的数据流架构最大限度地减少了通信和 I/O 开销，在数据密集型场景中实现了业界领先的吞吐量。\n\n+ **灵活的 DAG 定义流水线**: 将您的算法逻辑与物理硬件解耦。通过 siiRL，您可以将复杂的 RL 工作流定义为一个简单的有向无环图 (DAG)，从而实现快速、经济且无需编写代码的实验。\n\n+ **跨硬件兼容性**: siiRL 现已正式支持华为昇腾 (Ascend) NPU，为在不同硬件平台上进行训练和推理提供了高性能的替代方案。\n\n+ **经过验证的性能与稳定性**: 在 7B 到 72B 尺寸的模型上进行了广泛的基准测试，siiRL 在各种任务中均表现出卓越的性能。其优势在长上下文和多模态训练等数据密集型工作负载中尤为明显。\n\n---\n\n## 📰 最新动态\n\n* **[2025/11]**: siiRL 现已支持视觉-语言-动作（VLA）模型训练，基于 [SRPO (Self-Referential Policy Optimization for Vision-Language-Action Models)](https://arxiv.org/pdf/2511.15605) 算法，实现了机器人任务的具身强化学习训练。详细使用方法请参考[文档](/docs/examples/embodied_srpo_example.rst)。\n\n* **[2025/09]**: siiRL 现已集成 Megatron 训练后端，并支持MoE模型训练。其性能已在 Qwen3-MoE 模型（30B、235B）上得到验证。\n\n* **[2025/09]**: siiRL通过与华为昇腾、沐曦科技、阿里云等主要厂商合作，现已支持在其GPU 集群上从 32 卡稳定扩展至 1024 卡，线性扩展效率超过 90%。\n\n* **[2025/09]**: siiRL 支持多智能体与环境之间进行多轮交互。\n\n* **[2025/07]**: siiRL 为 LaMAS 新增了 [MARFT](https://arxiv.org/pdf/2504.16129) 支持，可通过 Flex-POMDP 对 LLM 多智能体进行强化学习微调。\n\n* **[2025/07]**: siiRL 现已支持 [CPGD](https://arxiv.org/pdf/2505.12504v1)，这是一种通过正则化大幅度的策略更新来增强 RL 训练稳定性和性能的算法。\n\n* **[2025/07]**: 我们很开心向开源社区发布 siiRL！欢迎查阅我们的[论文](https://arxiv.org/abs/2507.13833)，深入了解其架构和评测。\n\n---\n\n## 💡 架构概览\n\nsiiRL 是一个为大规模集群设计的完全分布式强化学习框架。siiRL 采用多控制器模式，将所有计算和数据流均匀地分派到每个 GPU。siiRL 由三个主要组件构成：DAG Planner，DAG Workers 和 Data Coordinator.\n\n<div align=\"center\">\n  <img src=\"asset/overview.png\" width=\"650px\" alt=\"siiRL 架构概览\">\n  <p><i>图 1. 
siiRL 架构概览。</i></p>\n</div>\n\nsiiRL 是一个**完全分布式、多控制器的架构**。\n\n关键组件包括：\n* **DAG Planner**: 将用户定义的 DAG 转换为序列化、可供每个DAG Worker执行的流水线。\n* **DAG Workers**: 核心执行单元，每个DAG Worker绑定到单个 GPU，独立运行其分配的任务。\n* **Data Coordinator**: 一组分布式组件（`分布式数据加载器`和`分布式数据缓冲区`），无需中央协调器即可管理从初始加载到中间数据重分配的整个数据生命周期。\n\n### 典型支持的模型与算法\n\n<table style=\"width: 100%; table-layout: auto; border-collapse: collapse;\">\n  <thead align=\"center\" valign=\"bottom\">\n    <tr>\n      <th style=\"min-width: 120px;\">模型</th>\n      <th style=\"min-width: 120px;\">算法</th>\n    </tr>\n  </thead>\n  <tbody valign=\"top\">\n    <tr>\n      <td>\n        <b>Qwen2.5 系列</b>\n        <ul style=\"margin-left: 0; padding-left: 16px;\">\n          <li>Qwen2.5-7B </li>\n          <li>Qwen2.5-72B </li>\n          <li>Qwen2.5-VL-7B </li>\n          <li>Qwen2.5-VL-72B </li>\n        </ul>\n        <b>Qwen3 系列</b>\n        <ul style=\"margin-left: 0; padding-left: 16px;\">\n          <li>Qwen3-1.7B </li>\n          <li>Qwen3-30B </li>\n          <li>Qwen3-235B-A22B (MoE) </li>\n        </ul>\n        <b>VLA 模型</b>\n        <ul style=\"margin-left: 0; padding-left: 16px;\">\n          <li>OpenVLA </li>\n          <li>OpenVLA-OFT </li>\n        </ul>\n      </td>\n      <td>\n        <b>强化学习算法</b>\n        <ul style=\"margin-left: 0; padding-left: 16px;\">\n          <li>GRPO </li>\n          <li>PPO </li>\n          <li>DAPO </li>\n          <li>GSPO </li>\n        </ul>\n      </td>\n    </tr>\n  </tbody>\n</table>\n\n## 🧪 实验评测\n\n我们对 siiRL 的性能和扩展性进行了全面评测，并与业界领先的 RL 框架 verl 进行了比较。实验表明，siiRL 在所有指标上均表现出卓越的性能。\n\n### 端到端吞吐量\n在标准的 PPO 和 GRPO 算法下，siiRL 的吞吐量全面超越了基线系统。特别是在数据密集度更高的 GRPO 算法下，siiRL 通过其完全分布式的架构有效解决了数据瓶颈，实现了高达 **2.62 倍**的性能提升。\n\n<p align=\"center\">\n<img src=\"asset/ppo_performance_comparison.png\" width=\"80%\" alt=\"PPO 算法性能对比\"/>\n<br>\n<em>图 2: PPO 算法下端到端性能对比</em>\n</p>\n<p align=\"center\">\n<img src=\"asset/grpo_performance_comparison.png\" width=\"80%\" alt=\"GRPO 算法性能对比\"/>\n<br>\n<em>图 3: GRPO 算法下端到端性能对比</em>\n</p>\n\n### 大规模扩展性\nsiiRL 展示了近线性的扩展能力，可平滑扩展至 1024 张 GPU。相比之下，基线框架由于其单点数据瓶颈导致的 OOM (内存不足) 错误，在相同条件下运行失败。在基线系统所能支持的最大批量大小下，siiRL 的性能优势可高达 **7 倍**。\n\n<p align=\"center\">\n<img src=\"asset/scaling_trend_new.png\" width=\"80%\" alt=\"siiRL 扩展性测试\"/>\n<br>\n<em>图 4: siiRL 的扩展性测试</em>\n</p>\n\n<p align=\"center\">\n<img src=\"asset/batch_size_total_throughput_final.png\" width=\"80%\" alt=\"VLM 任务性能对比\"/>\n<br>\n<em>图 5: 在基线系统最大负载下的性能对比</em>\n</p>\n\n### 长上下文性能\n在处理长上下文任务时，数据传输开销成为主要瓶颈。siiRL 的分布式数据流设计使其性能优势随着上下文长度的增加而愈发明显，实现了高达 **2.03 倍**的吞吐量提升，并成功运行了基线系统无法处理的 72B 模型长上下文任务。\n\n<p align=\"center\">\n<img src=\"asset/context_length_comparison_with_oom_label.png\" width=\"80%\" alt=\"长上下文性能对比\"/>\n<br>\n<em>图 6: 长上下文场景下的性能对比</em>\n</p>\n\n### 模型收敛性\n实验证实，siiRL 的性能优化并未以牺牲模型精度为代价。在超参数相同的情况下，siiRL 的奖励和熵收敛曲线与基线系统完全一致，同时将总训练时间**减少了 21%**。\n\n<p align=\"center\">\n<img src=\"asset/reward_and_entropy_comparison_final.png\" width=\"45%\" alt=\"收敛曲线对比\"/>\n<br>\n<em>图 7: 模型收敛曲线对比</em>\n</p>\n\n---\n\n## 📚 相关资源\n\n<a href=\"https://siirl.readthedocs.io/en/latest/index.html\"><b>使用文档</b></a>\n\n- <a href=\"https://siirl.readthedocs.io/en/latest/start/install.html\"><b>安装指南</b></a>\n\n- <a href=\"https://siirl.readthedocs.io/en/latest/start/quickstart.html\"><b>快速入门: 运行 PPO/GRPO</b></a>\n\n---\n\n## 🗓️ 未来计划\n\nsiiRL 仍在积极开发中。我们对未来充满期待，并致力于在两个关键方向上扩展框架的功能：支持真实机器人 VLA 训练和训练推理分离。\n\n### 具身 VLA 训练与真实世界部署\n我们正在扩展视觉-语言-动作（VLA）能力，以支持**真实世界机器人部署**。\n\n### 训练-推理分离架构\n为增强部署灵活性和资源利用率，我们正在开发**解耦的训练-推理架构**。\n\n---\n\n## 🙏 致谢\n\n我们首先要感谢开源 RL 框架 
[verl](https://github.com/volcengine/verl)，我们使用它作为评测的主要基线系统。我们特别感谢其分层的 API 设计；我们复用了 verl 中的 `3DParallelWorker` 基类来管理 siiRL 中的系统组件。\n\nsiiRL 的构建也离不开其他优秀的开源项目。我们衷心感谢 PyTorch、Ray、vLLM、vLLM-Ascend 和 SGLang 团队的杰出工作。\n\n我们的工作旨在解决研究过程中发现的扩展性挑战，并提供灵活的工作流设计，希望 siiRL 能为社区的共同进步做出积极贡献。\n\n---\n\n## 🖋️ 如何引用\n\n如果您在研究中发现 siiRL 对您有帮助，请考虑引用我们的论文。\n\n```bibtex\n@misc{wang2025distflowfullydistributedrl,\n      title={DistFlow: A Fully Distributed RL Framework for Scalable and Efficient LLM Post-Training}, \n      author={Zhixin Wang and Tianyi Zhou and Liming Liu and Ao Li and Jiarui Hu and Dian Yang and Jinlong Hou and Siyuan Feng and Yuan Cheng and Yuan Qi},\n      year={2025},\n      eprint={2507.13833},\n      archivePrefix={arXiv},\n      primaryClass={cs.DC},\n      url={https://arxiv.org/abs/2507.13833}, \n}\n```\n"
  },
  {
    "path": "README.md",
    "content": "\n<div align=\"center\">\n  <img src=\"asset/sii.png\" width=\"100%\"/>\n  <br>\n</div>\n<br>\n\n<h1 align=\"center\">\nsiiRL: Shanghai Innovation Institute RL Framework for Advanced LLMs and Multi-Agent Systems\n</h1>\n\n<p align=\"center\">\n| <a href=\"https://arxiv.org/abs/2507.13833\"><b>📄 Paper</b></a> \n| <a href=\"https://siirl.readthedocs.io/en/latest/index.html\"><b>📚 Documentation</b></a> \n| <a href=\"asset/siiRL-feishu-group.png\">\n    <img src=\"asset/logo-feishu.png\" alt=\"Feishu Group QR Code\" height=\"15\" /> \n    <b> Feishu Group</b>\n  </a> \n| <a href=\"asset/siiRL-wechat-group.png\">\n    <img src=\"asset/logo-wechat.png\" alt=\"Wechat Group QR Code\" height=\"15\" /> \n    <b> Wechat Group</b>\n  </a> \n| <a href=\"README-zh.md\"><b>🇨🇳 中文</b></a> |\n</p>\n\n**siiRL** is a novel, **fully distributed reinforcement learning (RL) framework** designed to break the scaling barriers in LLM post-training. Developed by researchers from **Shanghai Innovation Institute**, siiRL tackles the critical performance bottlenecks that limit current state-of-the-art systems.\n\nBy eliminating the centralized controller common in other frameworks, siiRL delivers **near-linear scalability**, **dramatic throughput gains**, and **unprecedented flexibility** for RL-based LLM development.\n\n---\n\n## 🚀 Highlights\n\n+ **Near-Linear Scalability**: The multi-controller paradigm eliminates central bottlenecks by distributing control logic and data management across all workers, enabling near-linear scalability to thousands of GPUs.\n\n+ **SOTA Throughput**: Fully distributed dataflow architecture minimizes communication and I/O overhead, achieving SOTA throughput in data-intensive scenarios.\n\n+ **Flexible DAG-Defined Pipeline**: Decouple your algorithmic logic from the physical hardware. With siiRL, you can define complex RL workflows as a simple Directed Acyclic Graph (DAG), enabling rapid, cost-effective, and code-free experimentation.\n\n+ **Cross-Hardware Compatibility**: siiRL now officially supports Huawei's Ascend NPUs, providing a high-performance alternative for training and inference on different hardware platforms.\n\n+ **Proven Performance & Stability**: Extensively benchmarked on models from 7B to 72B, siiRL delivering excellent performance across a wide range of tasks. Its advantages are particularly evident in data-intensive workloads such as long-context and multi-modal training.\n\n---\n\n## 📰 News\n* **[2025/11]**: siiRL now supports Vision-Language-Action (VLA) model training with [SRPO (Self-Referential Policy Optimization for Vision-Language-Action Models)](https://arxiv.org/pdf/2511.15605), enabling embodied RL training on robotics tasks. See the [documentation](/docs/examples/embodied_srpo_example.rst) for usage instructions.\n* **[2025/09]**: Added an explanation of the siiRL [code implementation](/docs/code_explained/siiRL-code-explained.md) for interested users and developers. A [Chinese version](https://zhuanlan.zhihu.com/p/1951768778875605883) is also available on Zhihu.\n\n* **[2025/09]**:siiRL now integrates Megatron training backend with support for MoE training. 
Performance has been validated on Qwen3-MoE models (30B, 235B).\n\n* **[2025/09]**: siiRL now supports stable scaling on GPU clusters from 32 GPUs up to 1024 GPUs, with over 90% linear scalability efficiency, through collaboration with major manufacturers including Huawei Ascend, MetaX, and Alibaba PPU.\n\n* **[2025/09]**: siiRL supports multi-turn interactions between multiple agents and the environment.\n\n* **[2025/07]**: siiRL adds [MARFT](https://arxiv.org/pdf/2504.16129) support for LaMAS, enabling RL fine-tuning of multi-LLM agents via Flex-POMDP.\n\n* **[2025/07]**: siiRL now supports [CPGD](https://arxiv.org/pdf/2505.12504v1), a novel algorithm that enhances RL training stability and performance by regularizing large policy updates.\n\n* **[2025/07]**: We are excited to release siiRL to the open-source community! Check out our [paper](https://arxiv.org/abs/2507.13833) for a deep dive into the architecture and evaluation.\n\n---\n\n## 💡 Architecture Overview\n\nsiiRL is a fully distributed RL framework designed for scalability on large-scale clusters. siiRL employs a multi-controller paradigm that uniformly dispatches all computation and data flow across each GPU. siiRL consists of three main components: a DAG Planner, DAG Workers, and a Data Coordinator. \n\n<div align=\"center\">\n  <img src=\"asset/overview.png\" width=\"650px\" alt=\"Overview of siiRL\">\n  <p><i>Figure 1. Overview of siiRL.</i></p>\n</div>\n\nsiiRL addresses these bottlenecks with a **fully distributed, multi-controller architecture**.\n\nKey components include:\n* **DAG Planner**: Translates a user-defined logical workflow (DAG) into a serialized, executable pipeline for each worker.\n* **DAG Workers**: The core execution units, with each worker bound to a single GPU, running its assigned tasks independently.\n* **Data Coordinator**: A set of distributed components (`Distributed Dataloader` and `Distributed Databuffer`) that manage the entire data lifecycle, from initial loading to intermediate data redistribution, without a central coordinator.\n\n### Typical Supported Models & Algorithms\n\n<table style=\"width: 100%; table-layout: auto; border-collapse: collapse;\">\n  <thead align=\"center\" valign=\"bottom\">\n    <tr>\n      <th style=\"min-width: 120px;\">Models</th>\n      <th style=\"min-width: 120px;\">Algorithms</th>\n    </tr>\n  </thead>\n  <tbody valign=\"top\">\n    <tr>\n      <td>\n        <b>Qwen2.5 Series</b>\n        <ul style=\"margin-left: 0; padding-left: 16px;\">\n          <li>Qwen2.5-7B </li>\n          <li>Qwen2.5-72B </li>\n          <li>Qwen2.5-VL-7B </li>\n          <li>Qwen2.5-VL-72B </li>\n        </ul>\n        <b>Qwen3 Series</b>\n        <ul style=\"margin-left: 0; padding-left: 16px;\">\n          <li>Qwen3-1.7B </li>\n          <li>Qwen3-30B </li>\n          <li>Qwen3-235B-A22B (MoE) </li>\n        </ul>\n        <b>VLA Models</b>\n        <ul style=\"margin-left: 0; padding-left: 16px;\">\n          <li>OpenVLA </li>\n          <li>OpenVLA-OFT </li>\n        </ul>\n      </td>\n      <td>\n        <b>Reinforcement Learning</b>\n        <ul style=\"margin-left: 0; padding-left: 16px;\">\n          <li>GRPO </li>\n          <li>PPO </li>\n          <li>DAPO </li>\n          <li>GSPO </li>\n        </ul>\n      </td>\n    </tr>\n  </tbody>\n</table>\n\n## 🧪 Experiment\n\nWe conducted a comprehensive evaluation of siiRL's performance and scalability across various scenarios, comparing it with the SOTA RL framework, verl. 
The experiments demonstrate that siiRL exhibits outstanding performance across all metrics.\n\n### End-to-End Throughput\nUnder the standard PPO and GRPO algorithms, siiRL's throughput comprehensively surpasses the baseline. Particularly with the more data-intensive GRPO algorithm, siiRL effectively resolves data bottlenecks through its fully distributed architecture, achieving up to a 2.62x performance improvement.\n\n<p align=\"center\">\n<img src=\"asset/ppo_performance_comparison.png\" width=\"80%\" alt=\"PPO Algorithm Performance Comparison\"/>\n<br>\n<em>Figure 2: End-to-end performance comparison using the PPO algorithm </em>\n</p>\n<p align=\"center\">\n<img src=\"asset/grpo_performance_comparison.png\" width=\"80%\" alt=\"GRPO Algorithm Performance Comparison\"/>\n<br>\n<em>Figure 3: End-to-end performance comparison using the GRPO algorithm </em>\n</p>\n\n### Large-Scale Scalability\nsiiRL demonstrates near-linear scalability, smoothly extending up to 1024 GPUs. In contrast, the baseline framework fails under identical conditions due to OOM errors caused by its single-point data bottleneck. At the maximum batch size the baseline can support, siiRL's performance advantage can be as high as 7x.\n\n<p align=\"center\">\n<img src=\"asset/scaling_trend_new.png\" width=\"80%\" alt=\"siiRL Scalability Test\"/>\n<br>\n<em>Figure 4: Near-linear scalability of siiRL on VLM models </em>\n</p>\n\n<p align=\"center\">\n<img src=\"asset/batch_size_total_throughput_final.png\" width=\"80%\" alt=\"VLM Task Performance Comparison\"/>\n<br>\n<em>Figure 5: VLM task performance comparison under the baseline's maximum load </em>\n</p>\n\n### Long-Context Performance\nWhen processing long-context tasks, data transfer overhead becomes a major bottleneck. siiRL's distributed dataflow design allows its performance advantage to become more pronounced as context length increases, achieving up to a 2.03x throughput improvement and successfully running a 72B model long-context task that the baseline could not handle.\n\n<p align=\"center\">\n<img src=\"asset/context_length_comparison_with_oom_label.png\" width=\"80%\" alt=\"Long-Context Performance Comparison\"/>\n<br>\n<em>Figure 6: Performance comparison in long-context scenarios </em>\n</p>\n\n### Model Convergence\nExperiments confirm that siiRL's performance optimizations do not come at the cost of model accuracy. With identical hyperparameters, siiRL's reward and entropy convergence curves are identical to the baseline's, while reducing the total training time by 21%.\n\n<p align=\"center\">\n<img src=\"asset/reward_and_entropy_comparison_final.png\" width=\"45%\" alt=\"Convergence Curve Comparison\"/>\n<br>\n<em>Figure 7: Model convergence curve comparison </em>\n</p>\n\n---\n\n## 📚 Resources\n\n<a href=\"https://siirl.readthedocs.io/en/latest/index.html\"><b>Documentation</b></a>\n\n- <a href=\"https://siirl.readthedocs.io/en/latest/start/install.html\"><b>Installation</b></a>\n\n- <a href=\"https://siirl.readthedocs.io/en/latest/start/quickstart.html\"><b>Quickstart: Running PPO/GRPO</b></a>\n\n---\n\n## 🗓️ Future Plans\n\nsiiRL is under active development. 
We are excited about the future and are focused on extending the framework's capabilities in two key directions: training-inference separation and real-robot VLA training.\n\n### Training-Inference Separation Architecture\nTo enhance deployment flexibility and resource utilization, we are developing a **decoupled training-inference architecture**.\n\n### Embodied VLA Training & Real-World Deployment\nWe are expanding our Vision-Language-Action (VLA) capabilities to support **real-world robotics deployment**.\n\n\nWe welcome community contributions! Please see our [Contributing Guide](CONTRIBUTING.md) to get started.\n\n---\n\n## 🙏 Acknowledgement\n\nWe would first like to thank the open-source RL framework [verl](https://github.com/volcengine/verl), which we used as a primary baseline for our evaluations. We would like to directly acknowledge its hierarchical API design; we reuse the `3DParallelWorker` base class from verl to manage system components in siiRL.\n\nsiiRL is also built upon a foundation of other great open-source projects. We would like to thank the teams behind PyTorch, Ray, vLLM, vLLM-Ascend, and SGLang for their incredible work.\n\nOur work aims to address the scalability challenges identified during our research, and we hope siiRL can contribute positively to the community's collective progress.\n\n---\n\n## 🖋️ Citation\n\nIf you find siiRL useful in your research, please consider citing our paper.\n\n```bibtex\n@misc{wang2025distflowfullydistributedrl,\n      title={DistFlow: A Fully Distributed RL Framework for Scalable and Efficient LLM Post-Training}, \n      author={Zhixin Wang and Tianyi Zhou and Liming Liu and Ao Li and Jiarui Hu and Dian Yang and Jinlong Hou and Siyuan Feng and Yuan Cheng and Yuan Qi},\n      year={2025},\n      eprint={2507.13833},\n      archivePrefix={arXiv},\n      primaryClass={cs.DC},\n      url={https://arxiv.org/abs/2507.13833}, \n}\n```\n\n"
  },
  {
    "path": "docker/Dockerfile.cu124",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nFROM nvcr.io/nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04\n\nLABEL maintainer=\"SII AI Infra Team\"\n\n# base environment\nRUN apt update \\\n    && apt install -y rdma-core ibverbs-providers ibverbs-utils   \\\n    && apt install -y python3 python3-pip \\\n    && ln -sf /usr/bin/python3 /usr/bin/python  \\\n    && python -m pip install -U pip \\\n    && pip install -U setuptools wheel\n\n# dev tools\nRUN apt install -y git cmake ninja-build vim\n\n# python packages\nRUN pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124   \\\n    && pip install flashinfer-python -i https://flashinfer.ai/whl/cu124/torch2.6/   \\\n    && pip install flash-attn==2.7.3 --no-build-isolation   \\\n    && pip install vllm==0.8.5.post1    \\\n    && pip install accelerate codetiming datasets dill hydra-core pandas wandb loguru tensorboard qwen_vl_utils \\\n    && pip install 'ray[default]>=2.47.1' \\\n    && pip install opentelemetry-exporter-prometheus==0.47b0 \\\n    && pip install mbridge \\\n    && pip install numpy==1.26.4 \n\n# apex\nRUN git clone https://github.com/NVIDIA/apex.git \\\n    && cd apex \\\n    && MAX_JOBS=16 pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings \"--build-option=--cpp_ext\" --config-settings \"--build-option=--cuda_ext\" ./    \\\n    && cd .. && rm -rf apex\n\n# optional: sglang\nRUN pip install 'sglang[all]==0.4.6.post5'    \\\n    && pip install xgrammar==0.1.18\n\n# Install TransformerEngine\nRUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@v2.3\n\n# Install Megatron-LM\nRUN pip3 install --no-deps --no-cache-dir git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.2\n"
  },
  {
    "path": "docker/Dockerfile.cu126",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nFROM nvcr.io/nvidia/cuda:12.6.3-cudnn-devel-ubuntu22.04\n\nLABEL maintainer=\"SII AI Infra Team\"\n\n# base environment\nRUN apt update \\\n    && apt install -y rdma-core ibverbs-providers ibverbs-utils libnuma-dev  \\\n    && apt install -y python3 python3-pip \\\n    && ln -sf /usr/bin/python3 /usr/bin/python  \\\n    && python -m pip install -U pip \\\n    && pip install -U setuptools wheel\n\n# dev tools\nRUN apt install -y git cmake ninja-build vim\n\n# python packages\nRUN pip install torch==2.7.1 torchvision==0.22.1 torchaudio==2.7.1   \\\n    && pip install flash-attn==2.8.2 --no-build-isolation   \\\n    && pip install vllm==0.10.0    \\\n    && pip install accelerate codetiming datasets dill hydra-core pandas wandb loguru tensorboard qwen_vl_utils \\\n    && pip install mbridge \\\n    && pip install numpy==1.26.4 \n\n# apex\nRUN git clone https://github.com/NVIDIA/apex.git \\\n    && cd apex \\\n    && MAX_JOBS=16 pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings \"--build-option=--cpp_ext\" --config-settings \"--build-option=--cuda_ext\" ./    \\\n    && cd .. && rm -rf apex\n\n# optional: sglang\nRUN pip install 'sglang[all]==0.4.10.post2' \\\n    && pip install outlines==1.2.3 xgrammar==0.1.21\n\n# Install TransformerEngine\nRUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@v2.3\n\n# Install Megatron-LM\nRUN pip3 install --no-deps --no-cache-dir git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.2"
  },
  {
    "path": "docs/Makefile",
    "content": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line, and also\n# from the environment for the first two.\nSPHINXOPTS    ?=\nSPHINXBUILD   ?= sphinx-build\nSOURCEDIR     = .\nBUILDDIR      = build\n\n# Put it first so that \"make\" without argument is like \"make help\".\nhelp:\n\t@$(SPHINXBUILD) -M help \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\n.PHONY: help Makefile\n\n# Catch-all target: route all unknown targets to Sphinx using the new\n# \"make mode\" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).\n%: Makefile\n\t@$(SPHINXBUILD) -M $@ \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n"
  },
  {
    "path": "docs/conf.py",
    "content": "# Configuration file for the Sphinx documentation builder.\n#\n# For the full list of built-in configuration values, see the documentation:\n# https://www.sphinx-doc.org/en/master/usage/configuration.html\n\n# -- Project information -----------------------------------------------------\n# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information\n\nproject = \"siiRL\"\ncopyright = \"2025, SII AI Infra Team\"\nauthor = \"SII AI Infra Team\"\n\n# -- General configuration ---------------------------------------------------\n# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration\n\nextensions = [\n    \"myst_parser\",\n    \"sphinx.ext.autodoc\",\n    \"sphinx.ext.autosummary\",\n    \"sphinx.ext.autosectionlabel\",\n    \"sphinx.ext.napoleon\",\n    \"sphinx.ext.viewcode\",\n]\n# Use Google style docstrings instead of NumPy docstrings.\nnapoleon_google_docstring = True\nnapoleon_numpy_docstring = False\n\n# Make autosectionlabel use document name as prefix to avoid duplicate label warnings\nautosectionlabel_prefix_document = True\n\n# The suffix(es) of source filenames.\n# You can specify multiple suffix as a list of string:\nsource_suffix = {\n    \".rst\": \"restructuredtext\",\n    \".md\": \"markdown\",\n}\n\ntemplates_path = [\"_templates\"]\n\n# The language for content autogenerated by Sphinx. Refer to documentation\n# for a list of supported languages.\n#\n# This is also used if you do content translation via gettext catalogs.\n# Usually you set \"language\" from the command line for these cases.\nlanguage = \"en\"\n\n# List of patterns, relative to source directory, that match files and\n# directories to ignore when looking for source files.\n# This pattern also affects html_static_path and html_extra_path.\nexclude_patterns = [\"_build\", \"Thumbs.db\", \".DS_Store\", \"plan_*.md\"]\n\n# -- Options for HTML output -------------------------------------------------\n\n# The theme to use for HTML and HTML Help pages.  See the documentation for\n# a list of builtin themes.\n#\nhtml_theme = \"sphinx_rtd_theme\"\n\n# Add any paths that contain custom static files (such as style sheets) here,\n# relative to this directory. They are copied after the builtin static files,\n# so a file named \"default.css\" will overwrite the builtin \"default.css\".\nhtml_static_path = [\"_static\"]\n"
  },
  {
    "path": "docs/examples/config.rst",
    "content": ".. _config-explain-page:\n\n===================\nConfiguration Guide\n===================\n\nsiiRL uses Hydra-based configuration management with dataclass parameters. All configuration parameters are defined in the ``siirl/params/`` directory and can be set via command-line arguments.\n\nConfiguration Structure\n-----------------------\n\nParameters are organized into the following modules:\n\n- ``DataArguments``: Data-related parameters (``siirl/params/data_args.py``)\n- ``ActorRolloutRefArguments``: Actor, Rollout, and Reference model parameters (``siirl/params/model_args.py``)\n- ``CriticArguments``: Critic model parameters (``siirl/params/model_args.py``)\n- ``RewardModelArguments``: Reward model parameters (``siirl/params/model_args.py``)\n- ``AlgorithmArguments``: RL algorithm parameters (``siirl/params/model_args.py``)\n- ``TrainingArguments``: Training configuration (``siirl/params/training_args.py``)\n- ``DAGArguments``: DAG workflow parameters (``siirl/params/dag_args.py``)\n- ``ProfilerArguments``: Profiling parameters (``siirl/params/profiler_args.py``)\n\nAll parameters are combined into the ``SiiRLArguments`` class.\n\nUsage\n-----\n\nParameters are set via command-line arguments using dot notation:\n\n.. code-block:: bash\n\n   python -m siirl.main_dag \\\n     data.train_files=/path/to/train.parquet \\\n     data.train_batch_size=512 \\\n     actor_rollout_ref.model.path=/path/to/model \\\n     algorithm.adv_estimator=grpo \\\n     trainer.total_epochs=30\n\nData Parameters\n---------------\n\nLocation: ``siirl/params/data_args.py``\n\n.. code-block:: bash\n\n   data.tokenizer=null\n   data.train_files=/path/to/train.parquet\n   data.val_files=/path/to/val.parquet\n   data.prompt_key=prompt\n   data.max_prompt_length=512\n   data.max_response_length=512\n   data.train_batch_size=1024\n   data.return_raw_input_ids=False\n   data.return_raw_chat=False\n   data.return_full_prompt=False\n   data.shuffle=True\n   data.filter_overlong_prompts=False\n   data.filter_overlong_prompts_workers=1\n   data.truncation=error\n   data.image_key=images\n   data.trust_remote_code=True\n\n**Key Parameters:**\n\n- ``data.train_files``: Training data file path (Parquet format, can be list or single file)\n- ``data.val_files``: Validation data file path\n- ``data.prompt_key``: Field name for prompt in dataset (default: \"prompt\")\n- ``data.max_prompt_length``: Maximum prompt length (left-padded)\n- ``data.max_response_length``: Maximum response length for rollout generation\n- ``data.train_batch_size``: Training batch size per iteration\n- ``data.return_raw_input_ids``: Return original input_ids without chat template (for different RM chat templates)\n- ``data.shuffle``: Whether to shuffle data\n- ``data.truncation``: Truncation strategy (\"error\", \"left\", \"right\", \"middle\")\n- ``data.trust_remote_code``: Allow remote code execution for tokenizers\n\nCustom Dataset\n~~~~~~~~~~~~~~\n\n.. code-block:: bash\n\n   data.custom_cls.path=/path/to/custom_dataset.py\n   data.custom_cls.name=MyDatasetClass\n\n- ``data.custom_cls.path``: Path to custom dataset class file\n- ``data.custom_cls.name``: Name of the dataset class\n\nActor/Rollout/Reference Model\n------------------------------\n\nLocation: ``siirl/params/model_args.py``\n\nModel Configuration\n~~~~~~~~~~~~~~~~~~~\n\n.. 
code-block:: bash\n\n   actor_rollout_ref.hybrid_engine=True\n   actor_rollout_ref.model.path=/path/to/model\n   actor_rollout_ref.model.external_lib=null\n   actor_rollout_ref.model.enable_gradient_checkpointing=False\n   actor_rollout_ref.model.enable_activation_offload=False\n   actor_rollout_ref.model.trust_remote_code=False\n   actor_rollout_ref.model.use_remove_padding=False\n\n- ``actor_rollout_ref.model.path``: Huggingface model path (local or HDFS)\n- ``actor_rollout_ref.model.external_lib``: Additional Python packages to import\n- ``actor_rollout_ref.model.enable_gradient_checkpointing``: Enable gradient checkpointing\n- ``actor_rollout_ref.model.enable_activation_offload``: Enable activation offloading\n- ``actor_rollout_ref.model.trust_remote_code``: Allow remote code model loading\n- ``actor_rollout_ref.model.use_remove_padding``: Remove padding tokens for efficiency\n\nActor Configuration\n~~~~~~~~~~~~~~~~~~~\n\n.. code-block:: bash\n\n   actor_rollout_ref.actor.strategy=fsdp\n   actor_rollout_ref.actor.ppo_mini_batch_size=256\n   actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8\n   actor_rollout_ref.actor.grad_clip=1.0\n   actor_rollout_ref.actor.clip_ratio=0.2\n   actor_rollout_ref.actor.entropy_coeff=0.0\n   actor_rollout_ref.actor.use_kl_loss=False\n   actor_rollout_ref.actor.kl_loss_coef=0.001\n   actor_rollout_ref.actor.ppo_epochs=1\n   actor_rollout_ref.actor.optim.lr=1e-6\n\n- ``actor.strategy``: Backend strategy (\"fsdp\" or \"megatron\")\n- ``actor.ppo_mini_batch_size``: Mini-batch size for PPO updates (global across GPUs)\n- ``actor.ppo_micro_batch_size_per_gpu``: Micro-batch size per GPU (gradient accumulation)\n- ``actor.grad_clip``: Gradient clipping threshold\n- ``actor.clip_ratio``: PPO clip ratio\n- ``actor.use_kl_loss``: Enable KL loss in actor\n- ``actor.kl_loss_coef``: KL loss coefficient (for GRPO)\n- ``actor.optim.lr``: Learning rate\n\nReference Model\n~~~~~~~~~~~~~~~\n\n.. code-block:: bash\n\n   actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16\n   actor_rollout_ref.ref.fsdp_config.param_offload=False\n\n- ``ref.log_prob_micro_batch_size_per_gpu``: Micro-batch size for reference log prob computation\n- ``ref.fsdp_config.param_offload``: Enable parameter offloading (recommended for models > 7B)\n\nRollout Configuration\n~~~~~~~~~~~~~~~~~~~~~\n\n.. code-block:: bash\n\n   actor_rollout_ref.rollout.name=vllm\n   actor_rollout_ref.rollout.temperature=1.0\n   actor_rollout_ref.rollout.top_k=-1\n   actor_rollout_ref.rollout.top_p=1.0\n   actor_rollout_ref.rollout.tensor_model_parallel_size=2\n   actor_rollout_ref.rollout.gpu_memory_utilization=0.5\n   actor_rollout_ref.rollout.n=8\n\n- ``rollout.name``: Rollout backend (\"vllm\", \"sglang\", \"hf\")\n- ``rollout.temperature``: Sampling temperature\n- ``rollout.top_k``: Top-k sampling (-1 for vLLM, 0 for HF)\n- ``rollout.top_p``: Top-p sampling\n- ``rollout.tensor_model_parallel_size``: Tensor parallelism size (vLLM only)\n- ``rollout.gpu_memory_utilization``: GPU memory fraction for vLLM\n- ``rollout.n``: Number of responses per prompt (>1 for GRPO/RLOO)\n\nCritic Model\n------------\n\nLocation: ``siirl/params/model_args.py``\n\n.. code-block:: bash\n\n   critic.enable=True\n   critic.model.path=/path/to/critic_model\n   critic.ppo_mini_batch_size=256\n   critic.ppo_micro_batch_size_per_gpu=8\n   critic.optim.lr=1e-5\n\nMost parameters are similar to Actor configuration.\n\nReward Model\n------------\n\nLocation: ``siirl/params/model_args.py``\n\n.. 
code-block:: bash\n\n   reward_model.enable=False\n   reward_model.model.path=/path/to/reward_model\n   reward_model.model.input_tokenizer=null\n   reward_model.micro_batch_size_per_gpu=16\n   reward_model.reward_manager=naive\n\n- ``reward_model.enable``: Enable reward model (False = use only custom reward functions)\n- ``reward_model.model.input_tokenizer``: Input tokenizer path (if different from policy)\n- ``reward_model.reward_manager``: Reward manager type (\"naive\", \"batch\", \"parallel\", \"dapo\", \"embodied\")\n\nCustom Reward Function\n~~~~~~~~~~~~~~~~~~~~~~\n\n.. code-block:: bash\n\n   custom_reward_function.path=/path/to/my_reward.py\n   custom_reward_function.name=compute_score\n\n- ``custom_reward_function.path``: Path to custom reward function file\n- ``custom_reward_function.name``: Function name (default: \"compute_score\")\n\nSee :doc:`/user_interface/reward_interface` for details.\n\nAlgorithm Parameters\n--------------------\n\nLocation: ``siirl/params/model_args.py``\n\n.. code-block:: bash\n\n   algorithm.gamma=1.0\n   algorithm.lam=1.0\n   algorithm.adv_estimator=grpo\n   algorithm.use_kl_in_reward=False\n   algorithm.kl_penalty=kl\n   algorithm.kl_ctrl.type=fixed\n   algorithm.kl_ctrl.kl_coef=0.005\n   algorithm.workflow_type=DEFAULT\n\n- ``algorithm.gamma``: Discount factor\n- ``algorithm.lam``: GAE lambda (bias-variance tradeoff)\n- ``algorithm.adv_estimator``: Advantage estimator (\"gae\", \"grpo\", \"cpgd\", \"gspo\", \"rloo\")\n- ``algorithm.use_kl_in_reward``: Enable KL penalty in reward\n- ``algorithm.kl_penalty``: KL divergence calculation method (\"kl\", \"abs\", \"mse\", \"low_var_kl\", \"full\")\n- ``algorithm.workflow_type``: Workflow type (\"DEFAULT\", \"DAPO\", \"EMBODIED\")\n\nTraining Parameters\n-------------------\n\nLocation: ``siirl/params/training_args.py``\n\n.. code-block:: bash\n\n   trainer.total_epochs=30\n   trainer.project_name=siirl_examples\n   trainer.experiment_name=gsm8k\n   trainer.logger=['console', 'wandb']\n   trainer.nnodes=1\n   trainer.n_gpus_per_node=8\n   trainer.save_freq=10\n   trainer.val_before_train=True\n   trainer.test_freq=2\n\n- ``trainer.total_epochs``: Number of training epochs\n- ``trainer.project_name``: Project name (for logging)\n- ``trainer.experiment_name``: Experiment name (for logging)\n- ``trainer.logger``: Logger types ([\"console\", \"wandb\", \"tensorboard\", \"mlflow\"])\n- ``trainer.nnodes``: Number of nodes\n- ``trainer.n_gpus_per_node``: Number of GPUs per node\n- ``trainer.save_freq``: Checkpoint saving frequency (by iteration)\n- ``trainer.val_before_train``: Run validation before training\n- ``trainer.test_freq``: Validation frequency (by iteration)\n\nDAG Parameters\n--------------\n\nLocation: ``siirl/params/dag_args.py``\n\n.. code-block:: bash\n\n   dag.custom_pipeline_fn=null\n\n- ``dag.custom_pipeline_fn``: Custom pipeline function path (e.g., \"module:function\")\n\nSee :doc:`/user_interface/pipeline_interface` for custom pipeline details.\n\nComplete Example\n----------------\n\nGRPO Training\n~~~~~~~~~~~~~\n\n.. 
code-block:: bash\n\n   python -m siirl.main_dag \\\n     algorithm.adv_estimator=grpo \\\n     algorithm.workflow_type=DEFAULT \\\n     data.train_files=/path/to/gsm8k/train.parquet \\\n     data.train_batch_size=512 \\\n     data.max_prompt_length=2048 \\\n     data.max_response_length=4096 \\\n     actor_rollout_ref.model.path=/path/to/model \\\n     actor_rollout_ref.actor.optim.lr=1e-6 \\\n     actor_rollout_ref.actor.ppo_mini_batch_size=256 \\\n     actor_rollout_ref.rollout.name=vllm \\\n     actor_rollout_ref.rollout.tensor_model_parallel_size=2 \\\n     actor_rollout_ref.rollout.n=8 \\\n     custom_reward_function.path=siirl/user_interface/rewards_interface/custom_gsm8k_reward.py \\\n     custom_reward_function.name=compute_score \\\n     trainer.total_epochs=30 \\\n     trainer.n_gpus_per_node=8 \\\n     trainer.save_freq=10\n\nPPO Training\n~~~~~~~~~~~~\n\n.. code-block:: bash\n\n   python -m siirl.main_dag \\\n     algorithm.adv_estimator=gae \\\n     critic.enable=True \\\n     data.train_files=/path/to/data.parquet \\\n     actor_rollout_ref.model.path=/path/to/model \\\n     actor_rollout_ref.actor.optim.lr=1e-6 \\\n     actor_rollout_ref.rollout.name=vllm \\\n     critic.optim.lr=1e-5 \\\n     trainer.total_epochs=30\n\nDAPO Training\n~~~~~~~~~~~~~\n\n.. code-block:: bash\n\n   python -m siirl.main_dag \\\n     algorithm.workflow_type=DAPO \\\n     algorithm.adv_estimator=grpo \\\n     algorithm.filter_groups.enable=True \\\n     algorithm.filter_groups.metric=seq_final_reward \\\n     data.train_files=/path/to/data.parquet \\\n     actor_rollout_ref.model.path=/path/to/model \\\n     trainer.total_epochs=30\n\nParameter Reference\n-------------------\n\nFor the complete parameter definitions, see:\n\n- ``siirl/params/data_args.py`` - Data parameters\n- ``siirl/params/model_args.py`` - Model, algorithm parameters\n- ``siirl/params/training_args.py`` - Training parameters\n- ``siirl/params/dag_args.py`` - DAG workflow parameters\n- ``siirl/params/profiler_args.py`` - Profiler parameters\n\n\n"
  },
  {
    "path": "docs/examples/cpgd_example.rst",
    "content": "DeepScaleR Example with CPGD\n==============================\n\nIntroduction\n------------\n\nThis example demonstrates how to fine-tune a Large Language Model for advanced mathematical reasoning on the **DeepScaleR** dataset using **Clipped Policy Gradient Optimization with Policy Drift (CPGD)**, a novel reinforcement learning algorithm designed for enhanced training stability.\n\n**Paper:** `CPGD: Toward Stable Rule-based Reinforcement Learning for Language Models <https://arxiv.org/abs/2505.12504>`__\n\n**Dataset:** https://huggingface.co/datasets/agentica-org/DeepScaleR-Preview-Dataset\n\nWhile algorithms like PPO and GRPO are powerful, they can sometimes suffer from instability due to their reliance on importance-sampling ratios in the loss function. CPGD is proposed to mitigate these issues by providing a more stable policy update mechanism, making it a robust choice for complex reasoning tasks.\n\nCPGD Algorithm Overview\n-----------------------\n\nCPGD enhances training stability by making two key modifications to the standard policy gradient approach:\n\n1.  **Clipped Policy Gradient Objective**: Instead of directly using the policy ratio in the loss (which can cause high variance), CPGD uses a policy gradient objective. It then applies a clipping mechanism to the *logarithm* of the policy ratio. This prevents excessive policy updates when the ratio becomes too large, effectively keeping the optimization within a trusted region.\n2.  **Policy Drift Regularization**: CPGD introduces a *policy drift* term, which is a KL divergence penalty between the current policy and the old policy from the start of the training iteration. This acts as a dynamic regularizer, pulling the policy back if it strays too far, too quickly, thus preventing training collapse.\n\nTogether, these features allow CPGD to achieve consistent performance improvements while avoiding the instability often seen in other RL algorithms.\n\nStep 1: Prepare the Dataset\n---------------------------\n\nThe data preparation process is identical to other examples using this dataset. First, preprocess the DeepScaleR dataset into the required Parquet format.\n\n.. code:: bash\n\n   cd examples/data_preprocess\n   python3 deepscaler.py --local_dir ~/data/deepscaler\n\nThis command downloads, processes, and saves the training and testing sets in the `~/data/deepscaler` directory.\n\nStep 2: Download the Pre-trained Model\n--------------------------------------\n\nYou need a base model to start the CPGD training. In this example, we use `Qwen2.5-7B-Instruct`.\n\n- **Recommended: Download via CLI:** Use a tool like `huggingface-cli` to download the model to a local directory.\n\n  .. code:: bash\n\n     huggingface-cli download Qwen/Qwen2.5-7B-Instruct --local-dir ~/data/models/Qwen2.5-7B-Instruct\n\n- **Automatic Download:** You can also specify the model name directly in the `actor_rollout_ref.model.path` field of the run script, and the framework will download it automatically.\n\nStep 3: Perform CPGD Training\n-----------------------------\n\nWith the data and model ready, you can now launch the training job using the CPGD algorithm.\n\n**Reward Function**\n\nFor this task, we use the same rule-based reward function as in the PPO/GRPO examples. The framework's default reward mechanism performs an exact match on the final answer within the `\\\\boxed{...}` block. 
A correct answer receives a positive reward, and an incorrect one receives zero.\n\n**Training Script**\n\nBelow is a complete training script from `examples/cpgd_trainer/run_qwen2_5-7b.sh`. It is configured to use the CPGD algorithm (`algorithm.adv_estimator=cpgd`). Note the presence of CPGD-specific parameters like `actor_rollout_ref.actor.policy_drift_coeff` and `algorithm.weight_factor_in_cpgd`.\n\n.. literalinclude:: ../../examples/cpgd_trainer/run_qwen2_5-7b.sh\n   :language: bash\n   :caption: examples/cpgd_trainer/run_qwen2_5-7b.sh\n"
  },
  {
    "path": "docs/examples/deepscaler_example.rst",
    "content": "DeepScaleR Example with PPO\n=============================\n\nIntroduction\n------------\n\nThis example demonstrates how to fine-tune a Large Language Model for advanced mathematical reasoning using the **DeepScaleR** dataset.\n\n**Paper:** https://pretty-radio-b75.notion.site/DeepScaleR-Surpassing-O1-Preview-with-a-1-5B-Model-by-Scaling-RL-19681902c1468005bed8ca303013a4e2.\n\n**Dataset:** https://huggingface.co/datasets/agentica-org/DeepScaleR-Preview-Dataset\n\nThe core idea is to leverage Reinforcement Learning (RL), specifically Proximal Policy Optimization (PPO), to teach the model not just to find the correct answer, but to follow a logical, step-by-step reasoning process. This is achieved by rewarding the model based on the correctness of its final answer, which is extracted from a structured output.\n\nDataset Overview\n----------------\n\nThe DeepScaleR dataset consists of challenging mathematical problems. Each sample includes a question (`problem`), a detailed reasoning path (`solution`), and a final answer enclosed in a `\\\\boxed{}` block (`answer`).\n\n**An example from DeepScaleR:**\n\n**Prompt:**\n   \"Let $a_n=6^{n}+8^{n}$. Determine the remainder upon dividing $a_ {83}$ by $49$.\"\n\n**Solution:**\n   \"$6^{83} + 8^{83} = (6+8)(6^{82}-6^{81}8+\\\\ldots-8^{81}6+8^{82})$\\n Becuase $7|(6+8)$, we only consider $6^{82}-6^{81}8+\\\\ldots-8^{81}6+8^{82} \\\\pmod{7}$\\n$6^{82}-6^{81}8+\\\\ldots-8^{81}6+8^{82} \\\\equiv (-1)^{82} - (-1)^{81}+ \\\\ldots - (-1)^1 + 1 = 83 \\\\equiv 6 \\\\pmod{7}$\\n$6^{83} + 8^{83} \\\\equiv 14 \\\\cdot 6 \\\\equiv \\\\boxed{035} \\\\pmod{49}$\"\n\n**Answer:**\n   `35`\n\nStep 1: Prepare the Dataset\n---------------------------\n\nFirst, preprocess the DeepScaleR dataset into the required Parquet format. Our framework includes a script for this purpose.\n\n.. code:: bash\n\n   cd examples/data_preprocess\n   python3 deepscaler.py --local_dir ~/data/deepscaler\n\nThis will download the dataset from Hugging Face, process it, and save `train.parquet` and `test.parquet` files in the `~/data/deepscaler` directory.\n\nStep 2: Download the Pre-trained Model\n--------------------------------------\n\nYou need a base model to start the PPO training. In this example, we use `Qwen2.5-7B-Instruct`. There are several ways to make the model available to the trainer:\n\n- **Recommended: Download via CLI:** Use tools like `huggingface-cli` or `modelscope` to download the model to a local directory. This gives you more control.\n\n  .. code:: bash\n\n     # For Hugging Face\n     huggingface-cli download Qwen/Qwen2.5-7B-Instruct --local-dir ~/data/models/Qwen2.5-7B-Instruct --local-dir-use-symlinks False\n     \n     # For ModelScope\n     modelscope download Qwen/Qwen2.5-7B-Instruct --local_dir ~/data/models/Qwen2.5-7B-Instruct\n\n- **Automatic Download:** You can also specify the Hugging Face model name (e.g., `Qwen/Qwen2.5-7B-Instruct`) directly in the `actor_rollout_ref.model.path` and `critic.model.path` fields of your run script. The framework will attempt to download it automatically on the first run.\n\nStep 3: Perform PPO Training\n----------------------------\n\nWith the data and model ready, you can now launch the PPO training job.\n\n**Reward Function**\n\nFor this task, we use a simple but effective rule-based reward function. 
The framework's default reward mechanism will be used, which performs an exact match between the model's generated answer and the `ground_truth` from the dataset:\n\n- The model is prompted to provide its final answer inside a `\\\\boxed{...}` block.\n- The reward function checks if the content inside the generated `\\\\boxed{}` matches the ground truth answer.\n- A correct match receives a positive reward (e.g., 1.0), while an incorrect match or a malformed response receives zero reward.\n\n**Training Script**\n\nBelow is a complete training script based on `examples/ppo_trainer/run_qwen3-8b.sh`. It is configured for a single-node, multi-GPU setup. You should adapt paths like `HOME` to your environment.\n\n.. literalinclude:: ../../examples/ppo_trainer/run_qwen3-8b.sh\n   :language: bash\n   :caption: examples/ppo_trainer/run_qwen3-8b.sh\n"
  },
  {
    "path": "docs/examples/embodied_srpo_example.rst",
    "content": "Embodied SRPO Training\n======================\n\nIntroduction\n------------\n\nThis guide explains how to perform Embodied AI training using the SRPO algorithm with OpenVLA-oft models on tasks such as LIBERO. Embodied AI training involves an agent interacting with an environment, where the rewards are often based on task success.\n\nThis example demonstrates how to perform RL training on an `OpenVLA-oft-7B` model using the SRPO algorithm on the `libero_long` benchmark.\n\nStep 1: Prepare the Environment\n-------------------------------\n\nYou should use the provided Docker image for Embodied AI training, which contains all necessary dependencies including EGL support for rendering.\n\n**Docker Image**: ``siiai/siirl-vla:libero-egl-cu12.6`` (Available at `Docker Hub <https://hub.docker.com/r/siiai/siirl-vla>`_)\n\nEnsure you have the necessary environment variables set. This includes the path to the `siiRL` repository and any other dependencies.\n\n.. code:: bash\n\n   export SIIRL_DIR=\"/path/to/siiRL\"\n   export VJEPA2_DIR=\"$HOME/code/vjepa2\"  # V-JEPA 2 code repository (https://github.com/facebookresearch/vjepa2)\n   export PYTHONPATH=\"$SIIRL_DIR:/path/to/LIBERO:$VJEPA2_DIR:$PYTHONPATH\"\n\nStep 2: Prepare the Models\n--------------------------\n\nYou need the following models:\n\n1.  **SFT Model**: A Supervised Fine-Tuned (SFT) OpenVLA-oft model. You should select the model that corresponds to your specific task. For example, if you are training on `libero_long`, you should use the `Sylvest/OpenVLA-AC-PD-1traj-libero-long` model.\n\n    Here are the recommended Hugging Face models from the `Sylvest collection <https://huggingface.co/collections/Sylvest/srpo>`_:\n\n    - `Sylvest/OpenVLA-AC-PD-1traj-libero-object` (for `libero_object`)\n    - `Sylvest/OpenVLA-AC-PD-1traj-libero-spatial` (for `libero_spatial`)\n    - `Sylvest/OpenVLA-AC-PD-1traj-libero-goal` (for `libero_goal`)\n    - `Sylvest/OpenVLA-AC-PD-1traj-libero-long` (for `libero_long`)\n\n2.  **Visual Encoder**: A visual encoder model V-JEPA is **required** for processing visual observations.\n\n    - First, clone the V-JEPA 2 code repository from GitHub (`facebookresearch/vjepa2 <https://github.com/facebookresearch/vjepa2>`_):\n    \n      .. code:: bash\n\n         git clone https://github.com/facebookresearch/vjepa2.git $HOME/code/vjepa2\n\n      Make sure to add the V-JEPA 2 directory to your ``PYTHONPATH`` as shown in Step 1.\n\n    - Then, download the V-JEPA 2 model weights from Hugging Face: `Sylvest/vjepa2-vit-g <https://huggingface.co/Sylvest/vjepa2-vit-g>`_\n    \n      .. code:: bash\n\n         huggingface-cli download Sylvest/vjepa2-vit-g --local-dir $HOME/models/vjepa2\n\nSet the paths to these resources in your environment or script:\n\n.. code:: bash\n\n   export MODEL_PATH=$HOME/models/Sylvest/OpenVLA-AC-PD-1traj-libero-long\n   export VJEPA_MODEL_PATH=$HOME/models/vjepa2/vitg-384.pt\n\n.. note::\n   \n   You do not need to manually prepare a dataset file. 
``siiRL`` will automatically generate the task manifest (Parquet files) based on the environment configuration and save them to the path specified in ``TRAIN_DATA_PATH`` and ``TEST_DATA_PATH``.\n\nStep 3: Configure and Run the Training Script\n---------------------------------------------\n\nEmbodied AI training requires specific configurations to handle the environment interaction and action spaces.\n\nKey Configuration Parameters\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n**Embodied Specifics**:\n\n-   ``actor_rollout_ref.embodied.embodied_type``: The model type (e.g., ``openvla-oft``).\n-   ``actor_rollout_ref.embodied.action_token_len``: The dimensionality of the action space (e.g., 7 for xyz + quaternion + gripper).\n-   ``actor_rollout_ref.embodied.action_chunks_len``: The number of action steps predicted in one forward pass.\n-   ``actor_rollout_ref.embodied.video_embedding_model_path``: Path to the V-JEPA 2 video embedding model (e.g., ``$VJEPA_MODEL_PATH``).\n\n**Environment Configuration**:\n\n-   ``actor_rollout_ref.embodied.env.env_type``: The environment library (e.g., ``libero``).\n-   ``actor_rollout_ref.embodied.env.env_name``: The specific task suite name (e.g., ``libero_long``).\n-   ``actor_rollout_ref.embodied.env.num_envs``: Number of parallel environments per rollout worker. Default is 16 environments per GPU, and it is not recommended to exceed 16.\n-   ``actor_rollout_ref.embodied.env.max_steps``: Maximum steps per episode.\n\n**Algorithm Adjustments**:\n\n-   ``algorithm.embodied_sampling.filter_accuracy``: Enable filtering of prompts based on estimated success rate.\n-   ``algorithm.embodied_sampling.accuracy_lower_bound``: Lower threshold for filtering (e.g., 0.1).\n-   ``algorithm.embodied_sampling.accuracy_upper_bound``: Upper threshold for filtering (e.g., 0.9).\n\nComplete Training Script\n~~~~~~~~~~~~~~~~~~~~~~~~\n\nBelow is an example script `run_embodied_srpo.sh` to run SRPO training on `libero_long`.\n\n**Note**: The siiRL repository provides ready-to-use training scripts for all four LIBERO tasks in the `examples/embodied_srpo_trainer/` directory:\n\n-   ``run_openvla_oft_libero_long.sh``\n-   ``run_openvla_oft_libero_goal.sh``\n-   ``run_openvla_oft_libero_object.sh``\n-   ``run_openvla_oft_libero_spatial.sh``\n\nTo train on a specific task, modify the following paths in the script to match your actual environment:\n\n-   ``SIIRL_DIR``: Path to the siiRL repository\n-   ``VJEPA2_DIR``: Path to the V-JEPA2 repository (for ``PYTHONPATH``)\n-   ``HOME_PATH``: Your home directory or base path for models and data\n-   ``MODEL_PATH``: Path to the corresponding SFT model for the task\n-   ``VJEPA_MODEL_PATH``: Path to the V-JEPA 2 model weights file\n\n**Note**: LIBERO is pre-installed in the Docker image at ``/root/LIBERO/`` and does not need to be modified.\n\n.. 
code-block:: bash\n\n    #!/usr/bin/env bash\n    # ===================================================================================\n    # ===    Embodied AI SRPO Training with OpenVLA-OFT on LIBERO-LONG               ===\n    # ===================================================================================\n    # \n\n    set -e\n\n    # --- Environment Setup (Critical for siiRL) ---\n    export SIIRL_DIR=\"${SIIRL_DIR:-your_siirl_path}\"\n    export PYTHONPATH=\"$SIIRL_DIR:/root/LIBERO/:${VJEPA2_DIR:-your_vjepa2_path}:$PYTHONPATH\"\n\n    # --- Experiment and Model Definition ---\n    export DATASET=libero_long\n    export ALG=srpo\n    export MODEL_NAME=openvla-oft-7b\n    export MODEL_TYPE=openvla-oft\n\n    # --- Path Definitions (USER PROVIDED) ---\n    export HOME_PATH=${HOME_PATH:-your_home_path}\n    export TRAIN_DATA_PATH=$HOME_PATH/data/train.parquet # generated automatically\n    export TEST_DATA_PATH=$HOME_PATH/data/test.parquet # generated automatically\n    export MODEL_PATH=$HOME_PATH/models/Sylvest/OpenVLA-AC-PD-1traj-libero-long\n    export VJEPA_MODEL_PATH=$HOME_PATH/models/vjepa2/vitg-384.pt\n\n    # Base output paths\n    export BASE_CKPT_PATH=ckpts\n    export BASE_TENSORBOARD_PATH=tensorboard\n\n    # --- Embodied AI Specific Parameters ---\n    export ACTION_TOKEN_LEN=7        # 7 dimensions: xyz (3), quaternion (3), gripper (1)\n    export ACTION_CHUNKS_LEN=8       # OpenVLA-OFT uses 8-step action chunks\n    export NUM_ENVS=16               # actor_rollout_ref.embodied.env.num_envs\n    export MAX_EPISODE_STEPS=512     # actor_rollout_ref.embodied.env.max_steps\n\n    # --- Data and Sampling Parameters ---\n    export VAL_BATCH_SIZE=496                      # Validation batch size\n    export MAX_PROMPT_LENGTH=256\n    export MAX_RESPONSE_LENGTH=128\n\n    # --- Embodied Sampling Parameters ---\n    export FILTER_ACCURACY=True                    # Enable accuracy-based filtering\n    export ACCURACY_LOWER_BOUND=0.1                # Only keep prompts with success rate >= 0.1\n    export ACCURACY_UPPER_BOUND=0.9                # Only keep prompts with success rate <= 0.9\n    export FILTER_TRUNCATED=False                  # Filter truncated episodes (uses env.max_steps)\n    export OVERSAMPLE_FACTOR=1                     # Oversample factor for filtering\n\n    # --- Training Hyperparameters ---\n    export TRAIN_BATCH_SIZE=64       # data.train_batch_size\n    export PPO_MINI_BATCH_SIZE=4     # actor_rollout_ref.actor.ppo_mini_batch_size\n                                     # Note: actual ppo_mini_batch_size = PPO_MINI_BATCH_SIZE * ROLLOUT_N_SAMPLES\n    export ROLLOUT_N_SAMPLES=8       # REUSED: Number of samples per prompt\n    export PPO_EPOCHS=1              # actor_rollout_ref.actor.ppo_epochs\n\n    # Algorithm parameters\n    export LEARNING_RATE=5e-6\n    export WEIGHT_DECAY=0.0          # actor_rollout_ref.actor.optim.weight_decay\n    export CLIP_RATIO_HIGH=0.28      # actor_rollout_ref.actor.clip_ratio_high\n    export CLIP_RATIO_LOW=0.2        # actor_rollout_ref.actor.clip_ratio_low\n    export ENTROPY_COEFF=0.0\n    export TEMPERATURE=1.6\n    export GAMMA=1.0\n    export LAM=1.0\n    export GRAD_CLIP=1.0\n\n    # --- Image/Video Processing ---\n    export IMG_SIZE=384              # actor_rollout_ref.embodied.img_size\n    export ENABLE_FP16=True          # actor_rollout_ref.embodied.enable_fp16\n    export 
EMBEDDING_MODEL_OFFLOAD=False  # actor_rollout_ref.embodied.embedding_model_offload\n    export CENTER_CROP=True          # actor_rollout_ref.embodied.center_crop\n    export NUM_IMAGES_IN_INPUT=1     \n    export NUM_STEPS_WAIT=10           # Environment stabilization steps\n\n    # --- Trainer Configuration ---\n    export SAVE_FREQ=5              \n    export TEST_FREQ=5              \n    export TOTAL_EPOCHS=1000         # trainer.total_epochs\n    export MAX_CKPT_KEEP=5           # trainer.max_actor_ckpt_to_keep\n    export VAL_BEFORE_TRAIN=True     # trainer.val_before_train\n\n    # --- Multi-node distributed training ---\n    export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}\n    export NNODES=${PET_NNODES:-1}\n    export NODE_RANK=${PET_NODE_RANK:-0}\n    export MASTER_ADDR=${MASTER_ADDR:-localhost}\n    export MASTER_PORT=${MASTER_PORT:-29500}\n\n    # --- Environment Variables ---\n    export MUJOCO_GL=egl\n    export PYOPENGL_PLATFORM=egl\n    export GLOO_SOCKET_TIMEOUT=600\n\n    # --- Output Paths and Experiment Naming ---\n    timestamp=$(date +%Y%m%d_%H%M%S)\n    export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_${NNODES}nodes\n    export PROJECT_NAME=siirl_embodied_${DATASET}\n    export EXPERIMENT_NAME=openvla_oft_srpo_fsdp\n    export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}/${timestamp}\n    export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_${timestamp}\n\n    # --- Define the Training Command ---\n    TRAINING_CMD=(\n        python3 -m siirl.client.main_dag\n        --config-name=embodied_srpo_trainer\n        \n        # Data configuration\n        data.train_files=$TRAIN_DATA_PATH\n        data.val_files=$TEST_DATA_PATH\n        data.train_batch_size=$TRAIN_BATCH_SIZE\n        data.val_batch_size=$VAL_BATCH_SIZE\n        data.max_prompt_length=$MAX_PROMPT_LENGTH\n        data.max_response_length=$MAX_RESPONSE_LENGTH\n        \n        # Algorithm configuration\n        algorithm.workflow_type=embodied\n        algorithm.adv_estimator=grpo\n        algorithm.gamma=$GAMMA\n        algorithm.lam=$LAM\n        algorithm.norm_adv_by_std_in_grpo=True\n        \n        # Embodied sampling configuration (aligned with DAPO architecture)\n        algorithm.embodied_sampling.filter_accuracy=$FILTER_ACCURACY\n        algorithm.embodied_sampling.accuracy_lower_bound=$ACCURACY_LOWER_BOUND\n        algorithm.embodied_sampling.accuracy_upper_bound=$ACCURACY_UPPER_BOUND\n        algorithm.embodied_sampling.filter_truncated=$FILTER_TRUNCATED\n        algorithm.embodied_sampling.oversample_factor=$OVERSAMPLE_FACTOR\n        \n        # Model configuration\n        actor_rollout_ref.model.path=$MODEL_PATH\n        actor_rollout_ref.model.enable_gradient_checkpointing=True\n        \n        # Actor configuration\n        actor_rollout_ref.actor.optim.lr=$LEARNING_RATE\n        actor_rollout_ref.actor.optim.weight_decay=$WEIGHT_DECAY\n        actor_rollout_ref.actor.ppo_mini_batch_size=$PPO_MINI_BATCH_SIZE\n        actor_rollout_ref.actor.ppo_epochs=$PPO_EPOCHS\n        actor_rollout_ref.actor.grad_clip=$GRAD_CLIP\n        actor_rollout_ref.actor.clip_ratio_high=$CLIP_RATIO_HIGH\n        actor_rollout_ref.actor.clip_ratio_low=$CLIP_RATIO_LOW\n        actor_rollout_ref.actor.entropy_coeff=$ENTROPY_COEFF\n        actor_rollout_ref.actor.shuffle=True\n        \n        # Actor FSDP configuration\n        actor_rollout_ref.actor.fsdp_config.param_offload=False\n        actor_rollout_ref.actor.fsdp_config.grad_offload=False\n        
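# Setting these FSDP offload flags to True moves params/grads/optimizer state to CPU to save GPU memory (at some speed cost)\n        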
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False\n        \n        # Rollout configuration\n        actor_rollout_ref.rollout.name=hf\n        actor_rollout_ref.rollout.n=$ROLLOUT_N_SAMPLES\n        actor_rollout_ref.rollout.temperature=$TEMPERATURE\n        actor_rollout_ref.rollout.do_sample=True\n        actor_rollout_ref.rollout.response_length=512\n        \n        # Embodied AI specific configuration\n        actor_rollout_ref.embodied.embodied_type=$MODEL_TYPE\n        actor_rollout_ref.embodied.action_token_len=$ACTION_TOKEN_LEN\n        actor_rollout_ref.embodied.action_chunks_len=$ACTION_CHUNKS_LEN\n        actor_rollout_ref.embodied.video_embedding_model_path=$VJEPA_MODEL_PATH\n        actor_rollout_ref.embodied.embedding_img_size=$IMG_SIZE\n        actor_rollout_ref.embodied.embedding_enable_fp16=$ENABLE_FP16\n        actor_rollout_ref.embodied.embedding_model_offload=$EMBEDDING_MODEL_OFFLOAD\n        actor_rollout_ref.embodied.center_crop=$CENTER_CROP\n        actor_rollout_ref.embodied.num_images_in_input=$NUM_IMAGES_IN_INPUT\n        actor_rollout_ref.embodied.unnorm_key=$DATASET\n        \n        # Environment configuration\n        actor_rollout_ref.embodied.env.env_type=libero\n        actor_rollout_ref.embodied.env.env_name=$DATASET\n        actor_rollout_ref.embodied.env.num_envs=$NUM_ENVS\n        actor_rollout_ref.embodied.env.max_steps=$MAX_EPISODE_STEPS\n        actor_rollout_ref.embodied.env.num_steps_wait=$NUM_STEPS_WAIT\n        actor_rollout_ref.embodied.env.num_trials_per_task=50\n        actor_rollout_ref.embodied.env.model_family=openvla\n        \n        # Critic configuration (SRPO doesn't use critic)\n        critic.use_critic_model=False\n        \n        # Trainer configuration\n        trainer.total_epochs=$TOTAL_EPOCHS\n        trainer.save_freq=$SAVE_FREQ\n        trainer.test_freq=$TEST_FREQ\n        trainer.max_actor_ckpt_to_keep=$MAX_CKPT_KEEP\n        trainer.logger=['console','tensorboard']\n        trainer.project_name=$PROJECT_NAME\n        trainer.experiment_name=$EXPERIMENT_NAME\n        trainer.nnodes=$NNODES\n        trainer.n_gpus_per_node=$N_GPUS_PER_NODE\n        trainer.default_local_dir=$CKPT_PATH\n        trainer.resume_mode=auto\n        trainer.val_before_train=$VAL_BEFORE_TRAIN\n    )\n\n    # ===================================================================================\n    # ===                          EXECUTION LOGIC                                    ===\n    # ===================================================================================\n\n    # --- Boilerplate Setup ---\n    set -e\n    set -o pipefail\n    set -x\n\n    # --- Infrastructure & Boilerplate Functions ---\n    start_ray_cluster() {\n        local RAY_HEAD_WAIT_TIMEOUT=600\n        export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n        export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n        export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200\n        export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120\n\n        local ray_start_common_opts=(\n            --num-gpus \"$N_GPUS_PER_NODE\"\n            --object-store-memory 100000000000\n            --memory 100000000000\n        )\n\n        if [ \"$NNODES\" -gt 1 ]; then\n            if [ \"$NODE_RANK\" = \"0\" ]; then\n                echo \"INFO: Starting Ray head node on $(hostname)...\"\n                export RAY_ADDRESS=\"$RAY_MASTER_ADDR:$RAY_MASTER_PORT\"\n                ray start --head --port=\"$RAY_MASTER_PORT\" 
--dashboard-port=\"$RAY_DASHBOARD_PORT\" \"${ray_start_common_opts[@]}\" --system-config='{\"gcs_server_request_timeout_seconds\": 60, \"gcs_rpc_server_reconnect_timeout_s\": 60}'\n                local start_time=$(date +%s)\n                while ! ray health-check --address \"$RAY_ADDRESS\" &>/dev/null; do\n                    if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head node. Exiting.\" >&2; ray stop --force; exit 1; fi\n                    echo \"Head node not healthy yet. Retrying in 5s...\"\n                    sleep 5\n                done\n                echo \"INFO: Head node is healthy.\"\n            else\n                local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n                echo \"INFO: Worker node $(hostname) waiting for head at $head_node_address...\"\n                local start_time=$(date +%s)\n                while ! ray health-check --address \"$head_node_address\" &>/dev/null; do\n                    if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head. Exiting.\" >&2; exit 1; fi\n                    echo \"Head not healthy yet. Retrying in 5s...\"\n                    sleep 5\n                done\n                echo \"INFO: Head is healthy. Worker starting...\"\n                ray start --address=\"$head_node_address\" \"${ray_start_common_opts[@]}\"\n            fi\n        else\n            echo \"INFO: Starting Ray in single-node mode...\"\n            ray start --head \"${ray_start_common_opts[@]}\"\n        fi\n    }\n\n    # --- Main Execution Function ---\n    main() {\n        local timestamp=$(date +\"%Y%m%d_%H%M%S\")\n        ray stop --force\n\n        export VLLM_USE_V1=1\n        export GLOO_SOCKET_TIMEOUT=600\n        export GLOO_TCP_TIMEOUT=600\n        export GLOO_LOG_LEVEL=DEBUG\n        export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}\n        export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}\n        export RAY_MASTER_ADDR=$MASTER_ADDR\n        \n        start_ray_cluster\n\n        if [ \"$NNODES\" -gt 1 ] && [ \"$NODE_RANK\" = \"0\" ]; then\n            echo \"Waiting for all $NNODES nodes to join...\"\n            local TIMEOUT=600; local start_time=$(date +%s)\n            while true; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$TIMEOUT\" ]; then echo \"Error: Timeout waiting for nodes.\" >&2; exit 1; fi\n                local ready_nodes=$(ray list nodes --format=json | python3 -c \"import sys, json; print(len(json.load(sys.stdin)))\")\n                if [ \"$ready_nodes\" -ge \"$NNODES\" ]; then break; fi\n                echo \"Waiting... ($ready_nodes / $NNODES nodes ready)\"\n                sleep 5\n            done\n            echo \"All $NNODES nodes have joined.\"\n        fi\n\n        if [ \"$NODE_RANK\" = \"0\" ]; then\n            echo \"INFO [RANK 0]: Starting main training command.\"\n            eval \"${TRAINING_CMD[@]}\" \"$@\"\n            echo \"INFO [RANK 0]: Training finished.\"\n            sleep 30; ray stop --force >/dev/null 2>&1\n        elif [ \"$NNODES\" -gt 1 ]; then\n            local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n            echo \"INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address.\"\n            while ray health-check --address \"$head_node_address\" &>/dev/null; do sleep 15; done\n            echo \"INFO [RANK $NODE_RANK]: Head node down. 
Exiting.\"\n        fi\n\n        echo \"INFO: Script finished on rank $NODE_RANK.\"\n    }\n\n    # --- Script Entrypoint ---\n    main \"$@\"\n\nStep 4: Checking the Results\n----------------------------\n\n1.  **Logs**: Monitor the console output for training progress and environment interaction stats.\n2.  **TensorBoard**: Use TensorBoard to visualize rewards, success rates, and other metrics.\n\n    .. code:: bash\n\n       tensorboard --logdir ./tensorboard\n\n3.  **Checkpoints**: Trained models are saved in the ``ckpts`` directory.\n\n"
  },
  {
    "path": "docs/examples/megatron_backend_example.rst",
    "content": "Megatron-LM Training Backend\n============================================\n\nIntroduction\n------------\n\nThis guide explains how to use the Megatron-LM backend in siiRL for RL training. Megatron-LM is a powerful library for training very large transformer models, and integrating it as a backend allows for efficient 5D parallelism (DP/TP/EP/PP/CP).\n\nThis example demonstrates how to fine-tune a `Qwen3-8B` model using the GRPO algorithm with the Megatron-LM as training backend.\n\nStep 1: Prepare the Dataset\n---------------------------\n\nFirst, ensure your dataset is in the required Parquet format. If you are using one of the example datasets like `gsm8k` or `deepscaler`, you can use the provided preprocessing scripts. For example, for `deepscaler`:\n\n.. code:: bash\n\n   cd examples/data_preprocess\n   python3 deepscaler.py --local_dir ~/data/deepscaler\n\nThis will download and process the dataset, saving `train.parquet` and `test.parquet` in the specified directory.\n\nStep 2: Download the Pre-trained Model\n--------------------------------------\n\nYou need a base model to start training. For this example, we'll use `Qwen3-8B`. Download it from Hugging Face or ModelScope to a local directory.\n\n.. code:: bash\n\n   # For Hugging Face\n   huggingface-cli download Qwen/Qwen3-8B-Instruct --local-dir ~/data/models/Qwen3-8B --local-dir-use-symlinks False\n   \n   # For ModelScope\n   modelscope download Qwen/Qwen3-8B-Instruct --local_dir ~/data/models/Qwen3-8B\n\nStep 3: Configure and Run the Training Script\n---------------------------------------------\n\nTo use the Megatron-LM backend, you need to modify the training configuration in your run script.\n\nKey Configuration Changes\n~~~~~~~~~~~~~~~~~~~~~~~~~\n\nThe main change is to set the training strategy to `megatron` and configure its parallelism parameters.\n\n1.  **Set the Strategy**: e.g., in the `TRAINING_CMD` array, set `actor_rollout_ref.actor.strategy=megatron`.\n2.  **Configure Parallelism**: Add Megatron-specific settings for 5D parallelism. For a 8B model on a single node with 8 GPUs, you might use 2-way tensor parallelism and 4-way pipeline parallelism, with sequence parallelism enabled.\n\n    .. code-block:: text\n\n        actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2\n        actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=4\n        actor_rollout_ref.actor.megatron.context_parallel_size=1\n        actor_rollout_ref.actor.megatron.sequence_parallel=True\n\n3.  **Configure Distributed Optimizer**: Add Megatron-specific settings for distributed optimizer. This allows for memory efficient training with ZeRO-1 optimization and is recommended for large models.\n\n    .. code-block:: text\n\n        actor_rollout_ref.actor.megatron.use_distributed_optimizer=True\n\n4.  **Configure Offloading**: Add Megatron-specific settings for parameter, gradient, and optimizer offload. This allows for parameter, gradient, and optimizer offloading to CPU to save GPU memory.\n\n    .. code-block:: text\n\n        actor_rollout_ref.actor.megatron.param_offload=True\n        actor_rollout_ref.actor.megatron.grad_offload=True\n        actor_rollout_ref.actor.megatron.optimizer_offload=True\n\nComplete Training Script\n~~~~~~~~~~~~~~~~~~~~~~~~\n\nBelow is a complete example script, `run_qwen3-8b-megatron.sh`, which is adapted from the standard GRPO script to use the Megatron backend. You will need to create this script yourself or adapt an existing one.\n\n.. 
code-block:: bash\n\n    #!/usr/bin/env bash\n    # ===================================================================================\n    # ===                       USER CONFIGURATION SECTION                            ===\n    # ===================================================================================\n\n    # --- For debugging\n    export HYDRA_FULL_ERROR=1\n    export SIIRL_LOG_VERBOSITY=INFO\n\n    # --- Experiment and Model Definition ---\n    export DATASET=deepscaler\n    export ALG=grpo\n    export MODEL_NAME=qwen3-8b\n\n    # --- Path Definitions ---\n    export HOME=${HOME:-\"/root\"} # Set your home path\n    export TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet\n    export TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet\n    export MODEL_PATH=$HOME/data/models/Qwen3-8B\n\n    # Base output paths\n    export BASE_CKPT_PATH=$HOME/ckpts\n    export BASE_TENSORBOARD_PATH=$HOME/tensorboard\n\n    # --- Key Training Hyperparameters ---\n    export TRAIN_BATCH_SIZE_PER_NODE=128\n    export PPO_MINI_BATCH_SIZE_PER_NODE=16\n    export PPO_MICRO_BATCH_SIZE_PER_GPU=8\n    export MAX_PROMPT_LENGTH=1024\n    export MAX_RESPONSE_LENGTH=2048\n    export ROLLOUT_GPU_MEMORY_UTILIZATION=0.45\n    export ROLLOUT_N=8\n    export SAVE_FREQ=30\n    export TEST_FREQ=10\n    export TOTAL_EPOCHS=30\n    export MAX_CKPT_KEEP=5\n\n    # ---- Megatron Parallelism Configuration ----\n    export ACTOR_REF_TP=2\n    export ACTOR_REF_PP=4\n    export ACTOR_REF_CP=1\n    export ACTOR_REF_SP=True\n\n    # --- Distributed Training & Infrastructure ---\n    export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}\n    export NNODES=${PET_NNODES:-1}\n    export NODE_RANK=${PET_NODE_RANK:-0}\n    export MASTER_ADDR=${MASTER_ADDR:-localhost}\n\n    # --- Output Paths and Experiment Naming ---\n    timestamp=$(date +\"%Y%m%d_%H%M%S\")\n    export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_megatron_${NNODES}nodes\n    export PROJECT_NAME=siirl_${DATASET}_${ALG}\n    export EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_megatron_experiment\n    export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_megatron_tensorboard/dlc_${NNODES}_$timestamp\n    export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_megatron_${NNODES}_$timestamp\n\n    # --- Calculated Global Hyperparameters ---\n    export TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))\n    export PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))\n\n    # --- Define the Training Command and its Arguments ---\n    TRAINING_CMD=(\n        python3 -m siirl.main_dag\n        algorithm.adv_estimator=\\$ALG\n        data.train_files=\\$TRAIN_DATA_PATH\n        data.val_files=\\$TEST_DATA_PATH\n        data.train_batch_size=\\$TRAIN_BATCH_SIZE\n        data.max_prompt_length=\\$MAX_PROMPT_LENGTH\n        data.max_response_length=\\$MAX_RESPONSE_LENGTH\n        actor_rollout_ref.model.path=\\$MODEL_PATH\n        actor_rollout_ref.model.enable_gradient_checkpointing=True\n        \n        # --- Megatron Backend Configuration ---\n        actor_rollout_ref.actor.strategy=megatron\n        actor_rollout_ref.actor.megatron.tensor_model_parallel_size=\\$ACTOR_REF_TP\n        actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=\\$ACTOR_REF_PP\n        actor_rollout_ref.actor.megatron.context_parallel_size=\\$ACTOR_REF_CP\n        actor_rollout_ref.actor.megatron.sequence_parallel=\\$ACTOR_REF_SP\n        
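# Distributed optimizer shards optimizer state across DP ranks (ZeRO-1 style; see Step 3 above)\n        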
actor_rollout_ref.actor.megatron.use_distributed_optimizer=True\n        actor_rollout_ref.actor.megatron.param_dtype=bfloat16\n        actor_rollout_ref.actor.megatron.param_offload=False\n        \n        # --- PPO & Other Hyperparameters ---\n        actor_rollout_ref.actor.optim.lr=1e-6\n        actor_rollout_ref.actor.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n        actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n        actor_rollout_ref.actor.grad_clip=1.0\n        \n        # --- Rollout (vLLM) Configuration ---\n        actor_rollout_ref.rollout.tensor_model_parallel_size=\\$ACTOR_REF_TP\n        actor_rollout_ref.rollout.name=vllm\n        actor_rollout_ref.rollout.gpu_memory_utilization=\\$ROLLOUT_GPU_MEMORY_UTILIZATION\n        actor_rollout_ref.rollout.n=\\$ROLLOUT_N\n        actor_rollout_ref.rollout.prompt_length=\\$MAX_PROMPT_LENGTH  \n        actor_rollout_ref.rollout.response_length=\\$MAX_RESPONSE_LENGTH\n        \n        # --- Trainer Configuration ---\n        trainer.logger=['console','tensorboard']\n        trainer.project_name=\\$PROJECT_NAME\n        trainer.experiment_name=\\$EXPERIMENT_NAME\n        trainer.n_gpus_per_node=\\$N_GPUS_PER_NODE\n        trainer.nnodes=\\$NNODES\n        trainer.save_freq=\\$SAVE_FREQ\n        trainer.test_freq=\\$TEST_FREQ\n        trainer.total_epochs=\\$TOTAL_EPOCHS\n        trainer.resume_mode=auto\n        trainer.max_actor_ckpt_to_keep=\\$MAX_CKPT_KEEP\n        trainer.default_local_dir=\\$CKPT_PATH\n        trainer.val_before_train=True\n    )\n\nStep 4: Checking the Results\n----------------------------\n\nDuring training, you can monitor the progress through several means:\n\n1.  **Console Logs**: The console will output detailed logs. Look for initialization messages from the Megatron backend to confirm it's being used. You should see logs pertaining to the setup of 5D parallelism.\n\n2.  **TensorBoard**: If you enabled the `tensorboard` logger, you can monitor training metrics in real-time.\n    \n    .. code:: bash\n\n       tensorboard --logdir $HOME/tensorboard\n\n    Navigate to the TensorBoard URL in your browser to view metrics such as reward, KL divergence, and loss curves.\n\n3.  **Checkpoints**: Checkpoints will be saved in the directory specified by `CKPT_PATH`. You can use these to resume training or for inference later.\n"
  },
  {
    "path": "docs/examples/mm_eureka_example.rst",
    "content": "MM-Eureka Example with GRPO\n===========================\n\nIntroduction\n------------\n\nThis guide details how to fine-tune a multi-modal Large Language Model using the **Group Relative Policy Optimization (GRPO)** algorithm on the **MM-Eureka** dataset. MM-Eureka is a challenging dataset designed to test mathematical reasoning that requires interpreting both text and images.\n\n**Paper:** https://arxiv.org/pdf/2503.07365.\n\n**Dataset:** https://huggingface.co/datasets/FanqingM/MM-Eureka-Dataset\n\nThe goal is to enhance a model's ability to perform complex reasoning by processing visual and textual information simultaneously. We use GRPO, an advanced RL algorithm, to optimize the model's policy.\n\nDataset Overview\n----------------\n\nMM-Eureka problems consist of a text-based question paired with one or more images. The model must understand the content of the image to solve the problem correctly.\n\n**An example from MM-Eureka:**\n\n**Prompt:**\n   .. image:: https://github.com/sii-research/siiRL/raw/main/docs/_static/cube.jpg\n      :width: 50%\n\n   Question: A cube loses one vertex after a 'corner' is removed. This geometric shape is ___ (fill in the number).\n\n**Answer:**\n   3\n\nStep 1: Data Preprocessing\n--------------------------\n\nThe raw MM-Eureka dataset, typically in `.jsonl` format, must be converted to Parquet. This involves not only structuring the text but also processing the associated images.\n\nThe script `examples/data_preprocess/mm_eureka.py` handles this. It performs the following actions:\n- Parses each line of the input JSONL file.\n\n- Reads the image file specified in `image_urls` and embeds its byte content directly into the Parquet file.\n\n- Formats the user prompts to include instructions for the desired output structure (`<think>...</think><answer>...</answer>`).\n\n- Splits the data into training and testing sets.\n\nRun the script with your dataset file:\n\n.. code:: bash\n\n   cd examples/data_preprocess\n   python3 mm_eureka.py --jsonl_file /path/to/your/mm_eureka_data.jsonl --output_dir ~/data/mm_eureka/\n\nStep 2: Defining the Reward Score\n---------------------------------\n\nA custom reward function is crucial for multi-modal reasoning. For MM-Eureka, we use a composite score defined in `siirl/utils/reward_score/mm_eureka.py`. This function evaluates two aspects of the model's response:\n\n1.  **Accuracy Reward**: This is the primary component. It parses the mathematical expression from the model's output (often in LaTeX) and compares it against the ground truth using the `math_verify` utility. This provides a robust check for mathematical correctness.\n2.  **Format Reward**: A smaller, secondary reward is given if the model correctly follows the required `<think>...</think><answer>...</answer>` structure. This encourages the model to generate well-formed, interpretable reasoning chains.\n\nThe final reward is a weighted sum of these two components (e.g., `0.9 * accuracy_reward + 0.1 * format_reward`), balancing correctness with style.\n\nStep 3: Download the Pre-trained Model\n--------------------------------------\n\nFor this multi-modal task, we use a powerful vision-language model like `Qwen2.5-VL-7B-Instruct`. Ensure the model is available locally for the training script.\n\n- **Recommended: Download via CLI:**\n\n  .. 
code:: bash\n\n     # For Hugging Face\n     huggingface-cli download Qwen/Qwen2.5-VL-7B-Instruct --local-dir ~/data/models/Qwen2.5-VL-7B-Instruct\n     \n     # For ModelScope\n     modelscope download Qwen/Qwen2.5-VL-7B-Instruct --local_dir ~/data/models/Qwen2.5-VL-7B-Instruct\n\n- **Automatic Download:** Alternatively, specify the model identifier directly in the run script's `actor_rollout_ref.model.path` field.\n\nStep 4: Perform GRPO Training\n-----------------------------\n\nWith the data and model prepared, you can launch the training job using the GRPO algorithm.\n\n**Training Script**\n\nThe script `examples/grpo_trainer/run_qwen2_5_vl-7b.sh` provides a complete configuration for this task. It sets up the environment, Ray cluster, and all necessary hyperparameters for GRPO training on the MM-Eureka dataset. Adapt the `HOME` path and other variables as needed for your environment.\n\n.. literalinclude:: ../../examples/grpo_trainer/run_qwen2_5_vl-7b.sh\n   :language: bash\n   :caption: examples/grpo_trainer/run_qwen2_5_vl-7b.sh "
  },
  {
    "path": "docs/hardware_tutorial/ascend_profiling_en.rst",
    "content": "Data Collection on Ascend Devices Based on the FSDP Backend\n============================================================\n\nLast updated: 08/14/2025.\n\nThis is a tutorial for using GRPO to collect data on Ascend devices based on the FSDP backend.\n\nConfiguration\n-------------\n\n- Global Collection Control: Use the configuration items in siirl/client/config/ppo_trainer.yaml to control the default collection mode.\n\nControl collection parameters using parameters in ppo_trainer.yaml:\n\n- enable: Whether to enable performance profiling.\n- save_path: The path to save collected data.\n\n- level: Collection level—options include level_none, level0, level1, and level2.\n- level_none: Disables all level-based data collection (turns off profiler_level).\n- level0: Collects high-level application data, low-level NPU data, and operator execution details on the NPU.\n- level1: Adds CANN layer AscendCL data and AI Core performance metrics on the NPU based on level0.\n- level2: Adds CANN layer Runtime data and AI CPU metrics based on level1.\n\n- with memory: Enables memory analysis (defaults to True).\n- record shapes: Enables recording of tensor shapes (defaults to False).\n- with npu: Enables collection of device-side performance data (defaults to True).\n- with cpu: Enables collection of host-side performance data (defaults to True).\n- with module: Enables recording of framework-level Python call stack information.\n- with stack: Enables recording of operator call stack information.\n- analysis: Enables automatic data analysis.\n- discrete: Enables discrete mode, collecting performance data for each stage separately (defaults to False).\n\n- roles: Collection stage - used in conjunction with the discrete parameter. Options include:\n\ngenerate, compute_reward, compute_old_log_prob, compute_ref_log_prob, compute_value, compute_advantage,\n\ntrain_critic, train_actor\n\n- all_ranks: Whether to collect data from all ranks.\n\n- ranks: List of ranks for which to collect data. If empty, no data is collected.\n\n- profile_steps: List of collection steps. For example, [2, 4] indicates that steps 2 and 4 will be collected. If set to null, no data is collected.\n\nExample\n-------\nDisable collection\n~~~~~~~~~~~~~~~~~~~~\n.. code:: yaml\n\n  profiler:\n    enable: False # disable profile\n\nEnd-to-end collection\n~~~~~~~~~~~~~~~~~~~~~\n\n.. code:: yaml\n\n  profiler:\n    steps: [1, 2, 5]\n    discrete: False\n\nThe run_qwen2_5-7b-npu-e2e_prof.sh script is provided in examples/grpo_trainer for reference.\n\nDiscrete mode collection\n~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. code:: yaml\n\n  profiler:\n    discrete: True\n    roles:['generate', 'train_actor']\n\nThe discrete mode acquisition script run_qwen2_5-7b-npu-discrete_prof.sh is provided in examples/grpo_trainer for reference.\n\nVisualization\n-------------\n\nThe acquired data is stored in the user-defined save_path and can be visualized using the MindStudio Insight tool，\nyou can refer to <https://www.hiascend.com/document/detail/zh/mindstudio/80RC1/GUI_baseddevelopmenttool/msascendinsightug/Insight_userguide_0002.html>.\n\n\nIf the analysis parameter is set to False, offline analysis is required after collection:\n\n.. 
code:: python\n\n        import argparse\n\n        from torch_npu.profiler.profiler import analyse\n\n        if __name__ == \"__main__\":\n            parser = argparse.ArgumentParser()\n            # Path to the directory configured as save_path during collection\n            parser.add_argument(\"--path\", type=str, required=True)\n            args = parser.parse_args()\n            # Run offline analysis on the collected profiling data\n            analyse(profiler_path=args.path)\n"
  },
  {
    "path": "docs/hardware_tutorial/ascend_quickstart.rst",
    "content": "Ascend NPU\n==========\n\nSiiRL is also supports for Huawei's Ascend NPU devices. This guide has been tested with the following hardware:\n\n- Atlas 200T A2 Box16\n\nInstallation Process\n--------------------\n\nCore Environment Requirements\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nEnsure your environment meets these core software version requirements:\n\n+---------------------+------------+\n| Software            | Version    |\n+---------------------+------------+\n| Python              | == 3.10    |\n+---------------------+------------+\n| CANN                | == 8.1.RC1 |\n+---------------------+------------+\n| PyTorch             | == 2.5.1   |\n+---------------------+------------+\n| torch_npu           | == 2.5.1   |\n+---------------------+------------+\n| mindspeed(Optional) | == 0.12.1  |\n+---------------------+------------+\n\nRecommended Base Image\n^^^^^^^^^^^^^^^^^^^^^^\n\nFor a smoother setup, we strongly recommend using our pre-built Docker image, which includes all necessary dependencies. Please note this pre-built docker image contains torch, torch-npu, vLLM and vLLM-Ascend packages, after pulling it you only need to install siiRL framework from source.\n\n.. code-block:: bash\n\n    docker pull crispig/verl_npu:cann8.1rc1-py3.10-torch2.5.1-vllm-ascend0.7.3.post1-250616\n\nCompiling vLLM and vllm-ascend [Optional]\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nProper integration of vLLM within siiRL requires compiling both `vllm` and `vllm-ascend` from source. Follow the steps below, paying close attention to the instructions specific to your hardware.\n\n.. note::\n    We recommend using the latest version of vllm v0.9.2 and vllm-ascend v0.9.0rc2, which support setting use_remove_padding=True.\n\n.. code-block:: bash\n    \n    # vllm\n    git clone -b v0.9.2 --depth 1 https://github.com/vllm-project/vllm.git\n    cd vllm\n    pip install -r requirements-build.txt\n\n    # For Atlas 200T A2 Box16\n    VLLM_TARGET_DEVICE=empty pip install -e . --extra-index https://download.pytorch.org/whl/cpu/\n\n.. code-block:: bash\n    \n    # vllm-ascend\n    git clone -b v0.9.0rc2 --depth 1 https://github.com/vllm-project/vllm-ascend.git\n    cd vllm-ascend\n    export COMPILE_CUSTOM_KERNELS=1\n    python setup.py install\n\nSiiRL Installation\n^^^^^^^^^^^^^^^^^^\n\nFinally, install the siiRL framework itself. DO NOT use the pip install command to install siiRL, it will cause dependency conflicts.\n\n.. code-block:: bash\n\n    git clone https://github.com/sii-research/siiRL.git\n    cd siirl\n    pip install -e .\n\nThird-Party Library Considerations\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nPlease be aware of the following specific requirements and limitations for certain libraries on Ascend hardware:\n\n+--------------+---------------+\n| Software     | Description   |\n+--------------+---------------+\n| transformers | v4.52.4       |\n+--------------+---------------+\n| flash_attn   | not supported |\n+--------------+---------------+\n| liger-kernel | not supported |\n+--------------+---------------+\n| tensordict   | 0.8.3 (ARM)   |\n+--------------+---------------+\n\n1.  Using `--flash_attention_2` through `transformers` is supported (requires `transformers` version >= 4.52.0).\n2.  Flash Attention acceleration via the `flash_attn` package is not supported.\n3.  `liger-kernel` is not supported.\n4.  For ARM servers, `tensordict` version 0.8.3 is required. You can manually install it after the main dependencies are installed.\n5.  
For x86 servers, the CPU version of `torchvision` must be installed.\n\n.. code-block:: bash\n\n    pip install torchvision==0.20.1+cpu --index-url https://download.pytorch.org/whl/cpu\n\nVerification with a Quick Start Example\n---------------------------------------\n\nTo ensure your setup is correct, we recommend performing a quick test run. The following example trains a Qwen2.5-0.5B model on the GSM8k dataset using the GRPO algorithm.\n\n1.  **Prepare the Dataset**\n    First, download and preprocess the GSM8k dataset. The provided script will convert it to the Parquet format required by the framework.\n\n.. code-block:: bash\n\n    python3 examples/data_preprocess/gsm8k.py --local_dir ~/data/gsm8k\n\n2.  **Run the Training Job**\n    Next, execute the training command below. Ensure you have set the `VLLM_ATTENTION_BACKEND` environment variable.\n\n.. code-block:: bash\n\n    set -x\n\n    python3 -m siirl.main_dag \\\n        algorithm.adv_estimator=grpo \\\n        data.train_files=/datasets/gsm8k/train.parquet \\\n        data.val_files=/datasets/gsm8k/test.parquet \\\n        data.train_batch_size=1024 \\\n        data.max_prompt_length=1024 \\\n        data.max_response_length=1024 \\\n        data.filter_overlong_prompts=True \\\n        data.truncation='error' \\\n        actor_rollout_ref.model.path=/models/Qwen2.5-0.5B-Instruct \\\n        actor_rollout_ref.actor.optim.lr=5e-8 \\\n        actor_rollout_ref.model.use_remove_padding=False \\\n        actor_rollout_ref.actor.ppo_mini_batch_size=32 \\\n        actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \\\n        actor_rollout_ref.actor.use_kl_loss=True \\\n        actor_rollout_ref.actor.entropy_coeff=0 \\\n        actor_rollout_ref.actor.kl_loss_coef=0.001 \\\n        actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n        actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n        actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n        actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n        actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \\\n        actor_rollout_ref.rollout.tensor_model_parallel_size=4 \\\n        actor_rollout_ref.rollout.name=vllm \\\n        actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \\\n        actor_rollout_ref.rollout.n=5 \\\n        actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n        actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \\\n        actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n        algorithm.use_kl_in_reward=False \\\n        trainer.critic_warmup=0 \\\n        trainer.logger=['console'] \\\n        trainer.project_name='siirl_grpo_example_gsm8k' \\\n        trainer.experiment_name='qwen2_05b_function_rm' \\\n        trainer.n_gpus_per_node=16 \\\n        trainer.nnodes=$NNODES \\\n        trainer.save_freq=-1 \\\n        trainer.test_freq=5 \\\n        trainer.total_epochs=300 \\\n        trainer.device=npu $@\n\n(Optional) Setting Up MindSpeed Training Backend Guide\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nRefer to the `MindSpeed README <https://gitee.com/ascend/MindSpeed>`_ for instructions on installing the MindSpeed acceleration library. Recommended versions: MindSpeed Core 0.12.1, Megatron-LM 0.12.2.\n\n.. warning::\n\n   Please be sure to install **megatron-core** via ``pip install``.\n   Using ``PYTHONPATH`` to point to megatron will crash the program.\n\nSet the siiRL worker model ``strategy`` to ``megatron``. 
For example: ``actor_rollout_ref.actor.strategy=megatron``.\n\nCustom MindSpeed parameters can be passed through the ``override_transformer_config`` option. For instance, to enable Flash Attention for the actor model, you can use:\n``+actor_rollout_ref.actor.megatron.override_transformer_config.use_flash_attn=True``.\n\nMindSpeed provides the same level of support for siiRL as for verl. For more feature details, please refer to the `MindSpeed+verl documentation <https://gitee.com/ascend/MindSpeed/blob/master/docs/user-guide/verl.md>`_.\n"
  },
  {
    "path": "docs/hardware_tutorial/metax_quickstart.rst",
    "content": "MetaX(沐曦) GPU\n===============\n\nSiiRL is also supports for MetaX's GPU devices. This guide has been tested with the following hardware:\n\n- 曦云 series GPU\n\nInstallation Process\n--------------------\n\nRecommended Base Image\n^^^^^^^^^^^^^^^^^^^^^^\n\nFor a smoother setup, we strongly recommend using our pre-built Docker image, which includes all necessary dependencies. Please refer to MetaX developer website: https://developer.metax-tech.com/softnova/docker, after pulling it you only need to install siiRL framework from source.\n\n.. code-block:: bash\n\n    docker pull siiai/siirl-metax:maca.ai3.1.0.1-torch2.6-py310-ubuntu22.04-amd64\n\nStart docker container\n^^^^^^^^^^^^^^^^^^^^^^\n\n.. code-block:: bash\n    \n    docker run -d -t --net=host --uts=host --ipc=host --privileged=true --group-add video \\\n    --shm-size 100gb --ulimit memlock=-1 --security-opt seccomp=unconfined \\\n    --security-opt apparmor=unconfined --device=/dev/dri --device=/dev/mxcd --device=/dev/infiniband \\\n    -v /data/:/data/ \\\n    --name siirl \\\n    siiai/siirl-metax:maca.ai3.1.0.1-torch2.6-py310-ubuntu22.04-amd64 bash\n\nSiiRL Installation\n^^^^^^^^^^^^^^^^^^\n\nFinally, install the siiRL framework itself. DO NOT use the pip install command to install siiRL, it will cause dependency conflicts.\n\n.. code-block:: bash\n\n    git clone https://github.com/sii-research/siiRL.git\n    cd siirl\n    # You need to comment out the libraries adapted for MetaX, such as ray and vllm, to prevent them from being overwritten.\n    # vllm>=0.8.5.post1\n    # ray[default]>=2.47.1\n    pip install -r requirements.txt\n    pip install -e .\n\nAdd environment variables for MetaX\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n.. code-block:: bash\n\n    # mx gpu env\n    export MACA_PATH=/opt/maca\n    export CUCC_PATH=${MACA_PATH}/tools/cu-bridge\n    export CUDA_PATH=${CUCC_PATH}\n    export MACA_CLANG_PATH=$MACA_PATH/mxgpu_llvm/bin\n    export PATH=${CUDA_PATH}/bin:${MACA_CLANG_PATH}:${PATH}\n    export LD_LIBRARY_PATH=${MACA_PATH}/tools/cu-bridge/lib/:${MACA_PATH}/lib:${MACA_PATH}/ompi/lib:${MACA_PATH}/mxgpu_llvm/lib:${LD_LIBRARY_PATH}\n    export PYTORCH_ENABLE_SAME_RAND_A100=1\n    export MCPYTORCH_DISABLE_PRINT=1\n    export MAX_JOBS=20\n    export VLLM_USE_V1=0\n    export MCCL_ENABLE_FC=0\n    export MCCL_MAX_NCHANNELS=8\n    export PYTHONUNBUFFERED=1\n\n    export MCCL_IB_HCA=mlx5\n    export MCCL_SOCKET_IFNAME=ens1\n    export GLOO_SOCKET_IFNAME=ens1\n    export SOCKET_NIC=ens1\n\nVerification with a Quick Start Example\n---------------------------------------\n\nTo ensure your setup is correct, we recommend performing a quick test run. The following example trains a Qwen2.5-0.5B model on the GSM8k dataset using the GRPO algorithm.\n\n1.  **Prepare the Dataset**\n    First, download and preprocess the GSM8k dataset. The provided script will convert it to the Parquet format required by the framework.\n\n.. code-block:: bash\n\n    python3 examples/data_preprocess/gsm8k.py --local_dir ~/data/gsm8k\n\n2.  **Run the Training Job**\n    Next, execute the training command below. Ensure you have set the `VLLM_ATTENTION_BACKEND` environment variable.\n\n.. 
code-block:: bash\n\n    # --- Experiment and Model Definition ---\n    export DATASET=gsm8k\n    export ALG=grpo\n    export MODEL_NAME=qwen2.5-05b\n\n    # --- Path Definitions ---\n    export HOME=/data/\n    export TRAIN_DATA_PATH=$HOME/$DATASET/train.parquet\n    export TEST_DATA_PATH=$HOME/$DATASET/test.parquet\n    export MODEL_PATH=$HOME/Qwen2.5-0.5B-Instruct\n\n    # Base output paths\n    export BASE_CKPT_PATH=ckpts\n    export BASE_TENSORBOARD_PATH=tensorboard\n\n    # --- Key Training Hyperparameters ---\n    export TRAIN_BATCH_SIZE_PER_NODE=512\n    export PPO_MINI_BATCH_SIZE_PER_NODE=256\n    export PPO_MICRO_BATCH_SIZE_PER_GPU=8\n    export MAX_PROMPT_LENGTH=1024\n    export MAX_RESPONSE_LENGTH=2048\n    export ROLLOUT_GPU_MEMORY_UTILIZATION=0.4\n    export ROLLOUT_TP=2\n    export ROLLOUT_N=8\n    export SAVE_FREQ=30\n    export TEST_FREQ=10\n    export TOTAL_EPOCHS=30\n    export MAX_CKPT_KEEP=5\n\n    # --- Multi-node (Multi-machine) distributed training environments ---\n\n    # Uncomment the following line and set the correct network interface if needed for distributed backend\n\n    # --- Distributed Training & Infrastructure ---\n    export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}\n    export NNODES=${PET_NNODES:-1}\n    export NODE_RANK=${PET_NODE_RANK:-0}\n    export MASTER_ADDR=${MASTER_ADDR:-localhost}\n\n    # --- Output Paths and Experiment Naming ---\n    export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes\n    export PROJECT_NAME=siirl_${DATASET}_${ALG}\n    export EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment\n    export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp\n    export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp\n\n    # --- Calculated Global Hyperparameters ---\n    export TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))\n    export PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))\n\n    # mx gpu env\n    export MACA_PATH=/opt/maca\n    export CUCC_PATH=${MACA_PATH}/tools/cu-bridge\n    export CUDA_PATH=${CUCC_PATH}\n    export MACA_CLANG_PATH=$MACA_PATH/mxgpu_llvm/bin\n    export PATH=${CUDA_PATH}/bin:${MACA_CLANG_PATH}:${PATH}\n    export LD_LIBRARY_PATH=${MACA_PATH}/tools/cu-bridge/lib/:${MACA_PATH}/lib:${MACA_PATH}/ompi/lib:${MACA_PATH}/mxgpu_llvm/lib:${LD_LIBRARY_PATH}\n    export PYTORCH_ENABLE_SAME_RAND_A100=1\n    export MCPYTORCH_DISABLE_PRINT=1\n    export MAX_JOBS=20\n    export VLLM_USE_V1=0\n    export MCCL_ENABLE_FC=0\n\n    export MCCL_MAX_NCHANNELS=8\n    export PYTHONUNBUFFERED=1\n    export MCCL_IB_HCA=mlx5\n    export MCCL_SOCKET_IFNAME=ens1\n    export GLOO_SOCKET_IFNAME=ens1\n    export SOCKET_NIC=ens1\n\n    # --- Define the Training Command and its Arguments ---\n    TRAINING_CMD=(\n        python3 -m siirl.main_dag\n        algorithm.adv_estimator=\\$ALG\n        data.train_files=\\$TRAIN_DATA_PATH\n        data.val_files=\\$TEST_DATA_PATH\n        data.train_batch_size=\\$TRAIN_BATCH_SIZE\n        data.max_prompt_length=\\$MAX_PROMPT_LENGTH\n        data.max_response_length=\\$MAX_RESPONSE_LENGTH\n        data.filter_overlong_prompts=True\n        data.truncation='error'\n        data.shuffle=False\n        actor_rollout_ref.model.path=\\$MODEL_PATH\n        actor_rollout_ref.actor.optim.lr=1e-6\n        actor_rollout_ref.model.use_remove_padding=True\n        actor_rollout_ref.model.use_fused_kernels=False\n        
actor_rollout_ref.actor.policy_drift_coeff=0.001\n        actor_rollout_ref.actor.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n        actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n        actor_rollout_ref.actor.use_kl_loss=True\n        actor_rollout_ref.actor.grad_clip=0.5\n        actor_rollout_ref.actor.clip_ratio=0.2\n        actor_rollout_ref.actor.kl_loss_coef=0.01\n        actor_rollout_ref.actor.kl_loss_type=low_var_kl\n        actor_rollout_ref.model.enable_gradient_checkpointing=True\n        actor_rollout_ref.actor.fsdp_config.param_offload=True\n        actor_rollout_ref.actor.fsdp_config.optimizer_offload=True\n        actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n        actor_rollout_ref.rollout.tensor_model_parallel_size=\\$ROLLOUT_TP\n        actor_rollout_ref.rollout.name=vllm\n        actor_rollout_ref.rollout.gpu_memory_utilization=\\$ROLLOUT_GPU_MEMORY_UTILIZATION\n        actor_rollout_ref.rollout.max_model_len=\\$MAX_RESPONSE_LENGTH\n        actor_rollout_ref.rollout.enable_chunked_prefill=False\n        actor_rollout_ref.rollout.enforce_eager=False\n        actor_rollout_ref.rollout.free_cache_engine=False\n        actor_rollout_ref.rollout.n=\\$ROLLOUT_N\n        actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n        actor_rollout_ref.ref.fsdp_config.param_offload=True\n        algorithm.weight_factor_in_cpgd='STD_weight'\n        algorithm.kl_ctrl.kl_coef=0.001\n        trainer.critic_warmup=0\n        trainer.logger=['console','tensorboard']\n        trainer.project_name=\\$PROJECT_NAME\n        trainer.experiment_name=\\$EXPERIMENT_NAME\n        trainer.n_gpus_per_node=\\$N_GPUS_PER_NODE\n        trainer.nnodes=\\$NNODES\n        trainer.save_freq=\\$SAVE_FREQ\n        trainer.test_freq=\\$TEST_FREQ\n        trainer.total_epochs=\\$TOTAL_EPOCHS\n        trainer.resume_mode=auto\n        trainer.max_actor_ckpt_to_keep=\\$MAX_CKPT_KEEP\n        trainer.default_local_dir=\\$CKPT_PATH\n        trainer.val_before_train=False\n    )\n\n    # ===================================================================================\n    # ===                  MAIN EXECUTION LOGIC & INFRASTRUCTURE                      ===\n    # ===================================================================================\n\n    # --- Boilerplate Setup ---\n    set -e\n    set -o pipefail\n    set -x\n\n    # --- Infrastructure & Boilerplate Functions ---\n    start_ray_cluster() {\n        local RAY_HEAD_WAIT_TIMEOUT=600\n        export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n        export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n        export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200\n        export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120\n\n        local ray_start_common_opts=(\n            --num-gpus \"$N_GPUS_PER_NODE\"\n            --object-store-memory 100000000000\n            --memory 100000000000\n        )\n\n        if [ \"$NNODES\" -gt 1 ]; then\n            if [ \"$NODE_RANK\" = \"0\" ]; then\n                echo \"INFO: Starting Ray head node on $(hostname)...\"\n                export RAY_ADDRESS=\"$RAY_MASTER_ADDR:$RAY_MASTER_PORT\"\n                ray start --head --port=\"$RAY_MASTER_PORT\" --dashboard-port=\"$RAY_DASHBOARD_PORT\" \"${ray_start_common_opts[@]}\" --system-config='{\"gcs_server_request_timeout_seconds\": 60, \"gcs_rpc_server_reconnect_timeout_s\": 60}'\n                local 
start_time=$(date +%s)\n                while ! ray health-check --address \"$RAY_ADDRESS\" &>/dev/null; do\n                    if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head node. Exiting.\" >&2; ray stop --force; exit 1; fi\n                    echo \"Head node not healthy yet. Retrying in 5s...\"\n                    sleep 5\n                done\n                echo \"INFO: Head node is healthy.\"\n            else\n                local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n                echo \"INFO: Worker node $(hostname) waiting for head at $head_node_address...\"\n                local start_time=$(date +%s)\n                while ! ray health-check --address \"$head_node_address\" &>/dev/null; do\n                    if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head. Exiting.\" >&2; exit 1; fi\n                    echo \"Head not healthy yet. Retrying in 5s...\"\n                    sleep 5\n                done\n                echo \"INFO: Head is healthy. Worker starting...\"\n                ray start --address=\"$head_node_address\" \"${ray_start_common_opts[@]}\"\n            fi\n        else\n            echo \"INFO: Starting Ray in single-node mode...\"\n            ray start --head \"${ray_start_common_opts[@]}\"\n        fi\n    }\n\n    # --- Main Execution Function ---\n    main() {\n        local timestamp=$(date +\"%Y%m%d_%H%M%S\")\n        ray stop --force\n\n        # export VLLM_USE_V1=0\n        export GLOO_SOCKET_TIMEOUT=600\n        export GLOO_TCP_TIMEOUT=600\n        export GLOO_LOG_LEVEL=DEBUG\n        export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}\n        export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}\n        export RAY_MASTER_ADDR=$MASTER_ADDR\n\n        start_ray_cluster\n\n        if [ \"$NNODES\" -gt 1 ] && [ \"$NODE_RANK\" = \"0\" ]; then\n            echo \"Waiting for all $NNODES nodes to join...\"\n            local TIMEOUT=600; local start_time=$(date +%s)\n            while true; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$TIMEOUT\" ]; then echo \"Error: Timeout waiting for nodes.\" >&2; exit 1; fi\n                local ready_nodes=$(ray list nodes --format=json | python3 -c \"import sys, json; print(len(json.load(sys.stdin)))\")\n                if [ \"$ready_nodes\" -ge \"$NNODES\" ]; then break; fi\n                echo \"Waiting... ($ready_nodes / $NNODES nodes ready)\"\n                sleep 5\n            done\n            echo \"All $NNODES nodes have joined.\"\n        fi\n\n        if [ \"$NODE_RANK\" = \"0\" ]; then\n            echo \"INFO [RANK 0]: Starting main training command.\"\n            eval \"${TRAINING_CMD[@]}\" \"$@\"\n            echo \"INFO [RANK 0]: Training finished.\"\n            sleep 30; ray stop --force >/dev/null 2>&1\n        elif [ \"$NNODES\" -gt 1 ]; then\n            local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n            echo \"INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address.\"\n            while ray health-check --address \"$head_node_address\" &>/dev/null; do sleep 15; done\n            echo \"INFO [RANK $NODE_RANK]: Head node down. Exiting.\"\n        fi\n\n        echo \"INFO: Script finished on rank $NODE_RANK.\"\n    }\n\n    # --- Script Entrypoint ---\n    main \"$@\"\n\n"
  },
  {
    "path": "docs/index.rst",
    "content": ".. siiRL documentation master file, created by\n   sphinx-quickstart on Wed Jul  9 15:26:45 2025.\n   You can adapt this file completely to your liking, but it should at least\n   contain the root `toctree` directive.\n\nsiiRL documentation\n===================\n\n.. toctree::\n   :maxdepth: 2\n   :caption: Quickstart\n\n   start/install\n   start/quickstart\n\n.. toctree::\n   :maxdepth: 2\n   :caption: Programming guide\n\n   programming_guide/siirl_architecture_guide\n   programming_guide/code_structure\n   programming_guide/siiRL_code_explained\n   programming_guide/srpo_code_explained\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Data Preparation\n\n   preparation/prepare_data\n   preparation/reward_function\n\n.. toctree::\n   :maxdepth: 2\n   :caption: User Define Interface\n\n   user_interface/filter_interface\n   user_interface/reward_interface\n   user_interface/pipeline_interface\n   user_interface/metrics_interface\n\n.. toctree::\n   :maxdepth: 2\n   :caption: Configurations\n\n   examples/config\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Example\n\n   examples/deepscaler_example\n   examples/mm_eureka_example\n   examples/cpgd_example\n   examples/megatron_backend_example\n   examples/embodied_srpo_example\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Hardware Support\n\n   hardware_tutorial/ascend_quickstart\n   hardware_tutorial/ascend_profiling_en\n   hardware_tutorial/metax_quickstart\n"
  },
  {
    "path": "docs/preparation/prepare_data.rst",
    "content": "Prepare Data for Post-Training\n========================================\n\nBefore starting the post-training job, we need to prepare the data for policy training. The data should be preprocessed and stored in Parquet format, which facilitates efficient distributed data loading and processing.\n\nWe provide several data preprocessing scripts for popular datasets under the ``examples/data_preprocess/`` directory, such as ``gsm8k.py``, ``math_dataset.py``, and ``deepscaler.py``. To support a new custom dataset, you will need to create a similar script.\n\nThis document uses the ``DeepScaleR`` dataset as an example to detail the data preparation process and its specifications.\n\nGeneral Data Preprocessing Workflow\n-----------------------------------\n\nA typical data preprocessing script involves the following steps:\n\n1.  **Load Raw Data**: Use a library like Hugging Face's ``datasets`` to load the original dataset from the Hub or local files.\n2.  **Define Processing Logic**: Implement a core mapping function (which we often name ``make_map_fn``) to convert each sample from the original dataset into the specific format required by our framework.\n3.  **Apply Transformation and Save**: Use the ``datasets.map()`` method to apply this function to the entire dataset. Then, save the processed result in Parquet format locally, with an option to upload it to a distributed file system like HDFS.\n\nHere is a simplified framework of the process:\n\n.. code:: python\n\n   import argparse\n   import os\n   import datasets\n   from siirl.utils.extras.hdfs_io import copy, makedirs\n\n   def make_map_fn(split_name):\n       # ... Define your data processing logic here ...\n       def process_fn(example, idx):\n           # ... Transform each data sample ...\n           return transformed_data\n       return process_fn\n\n   if __name__ == '__main__':\n       parser = argparse.ArgumentParser()\n       # ... Define arguments ...\n       args = parser.parse_args()\n\n       # 1. Load data\n       raw_dataset = datasets.load_dataset(...)\n       \n       # 2. Apply transformation\n       processed_dataset = raw_dataset.map(function=make_map_fn('train'), with_indices=True)\n\n       # 3. Save as Parquet\n       local_dir = args.local_dir\n       processed_dataset.to_parquet(os.path.join(local_dir, \"train.parquet\"))\n\n       # (Optional) Upload to HDFS\n       if args.hdfs_dir:\n           makedirs(args.hdfs_dir)\n           copy(src=local_dir, dst=args.hdfs_dir)\n\n\nDeepScaleR Dataset Processing in Practice\n-------------------------------------------\n\nLet's take ``examples/data_preprocess/deepscaler.py`` as a concrete example to demonstrate how to process the ``agentica-org/DeepScaleR-Preview-Dataset``.\n\nThe core task is to implement the ``make_map_fn`` function, which maps original fields (like ``problem``, ``answer``, and ``solution``) to the standard format required by the training framework.\n\n.. 
code:: python\n\n   data_source = \"agentica-org/DeepScaleR-Preview-Dataset\"\n   instruction_following = 'Let\\'s think step by step and output the final within \\\\boxed{}.'\n\n   def make_map_fn(split_name):\n\n       def process_fn(example, idx):\n           question_raw = example.pop(\"problem\") \n           answer_raw = example.pop(\"answer\") \n\n           question = question_raw + \" \" + instruction_following \n           solution = example.pop(\"solution\") \n           data = {\n               \"data_source\": data_source,\n               \"prompt\": [\n                   {\n                   \"role\": \"user\",\n                       \"content\": question,\n                   }\n               ],\n               \"ability\": \"math\",\n               \"reward_model\": {\"style\": \"rule\", \"ground_truth\": answer_raw},\n               \"extra_info\": {\n                   \"split\": split_name,\n                   \"index\": idx,\n                   \"answer\": solution, \n                   \"question\": question_raw, \n               },\n           }\n           \n           return data\n\n       return process_fn\n\nData Format Specification\n-------------------------\n\nTo ensure the framework can correctly parse and utilize the data, each sample processed by ``make_map_fn`` must contain the following five key fields:\n\n1.  ``data_source``: A string indicating the source or name of the dataset. This field is used to dynamically select the corresponding reward function during training.\n    - Example: ``\"agentica-org/DeepScaleR-Preview-Dataset\"``\n\n2.  ``prompt``: A list used to construct the model's input, formatted to be compatible with Hugging Face's Chat Template. The data loader will automatically apply the template and tokenize the input.\n    - Example: ``[{\"role\": \"user\", \"content\": \"What is 2+2? Let's think step by step...\"}]``\n\n3.  ``ability``: A string defining the domain or capability of the current task, such as ``\"math\"``, ``\"coding\"``, or ``\"general\"``.\n\n4.  ``reward_model``: A dictionary containing information needed to calculate the reward. Currently, the ``ground_truth`` field is primarily used during evaluation.\n    - **Note**: The ``ground_truth`` you provide must align with the logic of the corresponding reward function you implement. For a math problem, it might be the standard answer; for code generation, it could be a set of unit tests.\n    - Example: ``{\"style\": \"rule\", \"ground_truth\": \"\\\\boxed{4}\"}``\n\n5.  ``extra_info``: A dictionary for storing additional metadata, such as the original dataset split (train/test) or sample index. This information is not used directly in training but is useful for debugging and data traceability.\n\nBy following these specifications, you can prepare your dataset to be used smoothly within the SiiRL post-training pipeline."
  },
  {
    "path": "docs/preparation/reward_function.rst",
    "content": "Implementing Reward Functions for Datasets\n===========================================\n\nIn Reinforcement Learning for LLMs, the reward function is a critical component that guides the model's learning process. It quantitatively evaluates the quality of a generated response, signaling what constitutes a \"good\" or \"bad\" output. Our framework provides a flexible system for defining these rewards, supporting both pre-implemented logic for common datasets and fully customized functions for specific tasks.\n\nThe RewardManager\n-----------------\n\nThe ``RewardManager`` is the central hub for reward computation. As defined in `siirl/scheduler/reward.py`, its primary role is to orchestrate the scoring of generated responses by invoking a specified scoring function. Different managers, like `NaiveRewardManager` or `BatchRewardManager`, offer different strategies for handling this process. This design is consistent with the `verl` framework's architecture. [1]_\n\nThe typical workflow is as follows:\n1. The manager receives a `DataProto` object, which is a batch containing all necessary information.\n2. It extracts relevant fields, such as the model's generated text (`solution_strs`) and the reference answer (`ground_truth`).\n3. It passes this data to a designated scoring function (`compute_score_fn`) to calculate the reward for each item in the batch.\n\nThis design allows the core training loop to remain agnostic to the specifics of reward calculation, which are neatly encapsulated within the manager and its scoring function.\n\nReward Function Implementations\n-------------------------------\n\nYou can define reward logic in two ways: by using our pre-built functions or by creating your own.\n\nPre-implemented Functions\n~~~~~~~~~~~~~~~~~~~~~~~~~\n\nFor standard benchmarks, we provide ready-to-use reward functions in the `siirl/utils/reward_score/` directory. These cover datasets like `GSM8K` and `MATH`, implementing their standard evaluation logic. For instance, the `GSM8K` scorer extracts the final numerical answer and compares it to the ground truth.\n\nCustomized Functions\n~~~~~~~~~~~~~~~~~~~~\n\nFor novel tasks or custom evaluation criteria, you can supply your own reward function. This is configured via two parameters: `custom_reward_function.path` and `custom_reward_function.name`.\n\nLet's consider a practical example from the `run_qwen2_5-7b-custom_reward.sh` script, which uses a batch-processing reward function for efficiency.\n\n**1. Configuration in the script:**\n\nThe script specifies the path to the custom code, the function to use, and selects the `BatchRewardManager` to execute it.\n\n.. code-block:: bash\n\n   # ... other configurations ...\n   python3 -m siirl.main_dag \\\n       # ...\n       custom_reward_function.path=$HOME/rl/rewardfunc_gsm8k.py \\\n       custom_reward_function.name=compute_score \\\n       reward_model.reward_manager=batch \\\n       # ...\n\n**2. Implementation of the reward function:**\n\nThe corresponding `rewardfunc_gsm8k.py` file implements the `compute_score` function. This function is designed to process an entire batch of solutions at once, which is significantly more efficient than processing them one by one.\n\n.. code:: python\n\n   import re\n\n   def extract_solution(solution_str, method=\"strict\"):\n       # ... 
(logic to extract the final answer from text)\n       # For example, finds the number after \"####\"\n       if method == \"strict\":\n           solution = re.search(\"#### (\\\\-?[0-9\\\\.\\\\,]+)\", solution_str)\n           if solution is None: return None\n           final_answer = solution.group(0).split(\"#### \")[1].replace(\",\", \"\")\n           return final_answer\n       # ... other extraction logic ...\n\n   def compute_score(data_sources, solution_strs, ground_truths, extra_infos, method=\"strict\", score=1.0, **kwargs):\n       \"\"\"\n       Computes scores for a batch of solutions.\n       \"\"\"\n       scores = []\n       for solution_str, ground_truth in zip(solution_strs, ground_truths):\n           answer = extract_solution(solution_str=solution_str, method=method)\n           if answer is not None and answer == ground_truth:\n               scores.append(score)\n           else:\n               scores.append(0.0)\n       return scores\n\nThe function signature should accept lists of `solution_strs` and `ground_truths`. You can also pass custom parameters from your configuration, like `method` or `score`, by defining them under `custom_reward_function.reward_kwargs`. This allows you to easily experiment with different reward schemes without changing the code.\n\n.. [1] https://verl.readthedocs.io/en/latest/preparation/reward_function.html"
  },
  {
    "path": "docs/programming_guide/code_structure.rst",
    "content": "===============\nCode Structure\n===============\n\nThis document describes the code structure and architecture of siiRL.\n\nDirectory Structure\n-------------------\n\n.. code-block:: text\n\n   siirl/\n   ├── main_dag.py                   # Main entry point\n   ├── dag_worker/                   # DAG Worker implementation\n   ├── execution/                    # Execution engine\n   ├── engine/                       # Model engine\n   ├── data_coordinator/             # Data coordination\n   ├── params/                       # Configuration parameters\n   ├── environment/                  # Environment abstraction\n   └── user_interface/               # User interface\n\nCore Modules\n------------\n\ndag_worker/\n~~~~~~~~~~~\n\nDAG execution unit, one worker per GPU.\n\n.. code-block:: text\n\n   dag_worker/\n   ├── dagworker.py              # Core Worker class (~1320 lines)\n   ├── core_algos.py             # RL algorithm implementations\n   ├── dag_utils.py              # Utility functions\n   ├── checkpoint_manager.py     # Checkpoint management\n   ├── metrics_collector.py      # Metrics collection\n   ├── metric_aggregator.py      # Metrics aggregation\n   ├── validator.py              # Validation logic\n   ├── constants.py              # Constants\n   └── data_structures.py        # Data structures\n\n**Responsibilities:**\n\n- Execute TaskGraph nodes\n- Manage model Workers (Actor/Critic/Rollout/Reference/Reward)\n- Data flow and caching\n- Metrics collection and reporting\n- Checkpoint saving and loading\n\nexecution/\n~~~~~~~~~~\n\nExecution engine for DAG definition, scheduling, and metrics aggregation.\n\n.. code-block:: text\n\n   execution/\n   ├── dag/                      # DAG definition\n   │   ├── task_graph.py         # TaskGraph class\n   │   ├── node.py               # Node class\n   │   ├── builtin_pipelines.py  # Built-in Pipelines\n   │   ├── pipeline.py           # Pipeline Builder API\n   │   ├── config_loader.py      # Configuration loader\n   │   └── task_loader.py        # Task loader\n   ├── scheduler/                # Task scheduling\n   │   ├── task_scheduler.py     # Task scheduler\n   │   ├── launch.py             # Ray launcher\n   │   ├── process_group_manager.py  # Process group manager\n   │   ├── graph_updater.py      # Graph updater\n   │   ├── reward.py             # Reward scheduler\n   │   ├── enums.py              # Enum definitions\n   │   └── resource_manager.py   # Resource manager\n   ├── metric_worker/            # Distributed metrics aggregation\n   │   ├── metric_worker.py      # MetricWorker\n   │   └── utils.py\n   └── rollout_flow/             # Rollout flow\n       ├── multi_agent/          # Multi-agent support\n       └── multiturn/            # Multi-turn interaction\n\n**Responsibilities:**\n\n- DAG definition and validation\n- Task scheduling and resource allocation\n- Distributed metrics collection\n- Multi-agent/multi-turn interaction flow\n\nengine/\n~~~~~~~\n\nModel execution engine containing all model workers.\n\n.. 
code-block:: text\n\n   engine/\n   ├── actor/                    # Actor models\n   │   ├── base.py\n   │   ├── dp_actor.py           # FSDP Actor\n   │   ├── megatron_actor.py     # Megatron Actor\n   │   └── embodied_actor.py     # Embodied Actor\n   ├── critic/                   # Critic models\n   │   ├── base.py\n   │   ├── dp_critic.py\n   │   └── megatron_critic.py\n   ├── rollout/                  # Rollout engine\n   │   ├── base.py\n   │   ├── vllm_rollout/         # vLLM backend\n   │   ├── sglang_rollout/       # SGLang backend\n   │   ├── hf_rollout.py         # HuggingFace backend\n   │   └── embodied_rollout.py   # Embodied Rollout\n   ├── reward_model/             # Reward models\n   ├── reward_manager/           # Reward managers\n   │   ├── naive.py              # Simple reward\n   │   ├── batch.py              # Batch Reward Model\n   │   ├── parallel.py           # Parallel Reward Model\n   │   ├── dapo.py               # DAPO Reward\n   │   └── embodied.py           # Embodied Reward\n   ├── sharding_manager/         # Sharding management\n   ├── base_worker/              # Worker base classes\n   ├── fsdp_workers.py           # FSDP Worker\n   └── megatron_workers.py       # Megatron Worker\n\n**Responsibilities:**\n\n- Training and inference for Actor/Critic/Rollout/Reference/Reward models\n- Support for FSDP and Megatron backends\n- Support for vLLM/SGLang/HuggingFace inference backends\n\ndata_coordinator/\n~~~~~~~~~~~~~~~~~\n\nData coordinator for distributed data management.\n\n.. code-block:: text\n\n   data_coordinator/\n   ├── data_buffer.py            # Distributed data buffer\n   ├── dataloader/               # Data loading\n   │   ├── data_loader_node.py\n   │   ├── partitioned_dataset.py\n   │   ├── embodied_preprocess.py\n   │   └── vision_utils.py\n   ├── protocol.py               # Data protocol\n   └── sample.py                 # Sampling logic\n\n**Responsibilities:**\n\n- Distributed data buffering (per-server)\n- Data loading (per-GPU)\n- Data redistribution and load balancing\n\nparams/\n~~~~~~~\n\nParameter configuration using Hydra.\n\n.. code-block:: text\n\n   params/\n   ├── __init__.py               # SiiRLArguments\n   ├── parser.py                 # Configuration parser\n   ├── data_args.py              # Data parameters\n   ├── model_args.py             # Model parameters\n   ├── training_args.py          # Training parameters\n   ├── dag_args.py               # DAG parameters\n   ├── embodied_args.py          # Embodied parameters\n   └── profiler_args.py          # Profiler parameters\n\nenvironment/\n~~~~~~~~~~~~\n\nEnvironment abstraction for Embodied AI and multi-agent systems.\n\n.. code-block:: text\n\n   environment/\n   └── embodied/\n       ├── base.py               # Environment base class\n       ├── venv.py               # Vectorized environment\n       └── adapters/             # Environment adapters\n           └── libero.py         # Libero adapter\n\nuser_interface/\n~~~~~~~~~~~~~~~\n\nUser-defined interfaces.\n\n.. code-block:: text\n\n   user_interface/\n   ├── filter_interface/\n   │   ├── dapo.py               # DAPO dynamic sampling\n   │   └── embodied.py           # Embodied data filtering\n   └── rewards_interface/\n       └── custom_gsm8k_reward.py  # Custom reward example\n\n**Purpose:** Provides interfaces for user-defined node functions.\n\nData Structures\n---------------\n\nNodeOutput\n~~~~~~~~~~\n\nReturn value from node execution.\n\n.. 
code-block:: python\n\n   @dataclass\n   class NodeOutput:\n       batch: Any              # Data batch\n       metrics: Dict = None    # Metrics\n       info: Dict = None       # Additional info\n\nNode\n~~~~\n\nDAG node definition.\n\n.. code-block:: python\n\n   @dataclass\n   class Node:\n       node_id: str                    # Node ID\n       node_type: NodeType             # Node type\n       node_role: NodeRole             # Node role\n       dependencies: List[str]         # Dependency nodes\n       executable: Callable            # Executable function\n       executable_ref: str             # Function path\n       only_forward_compute: bool      # Forward only\n\nEnumerations\n~~~~~~~~~~~~\n\n**NodeType:**\n\n.. code-block:: python\n\n   class NodeType(Enum):\n       MODEL_INFERENCE = \"model_inference\"\n       MODEL_TRAIN = \"model_train\"\n       COMPUTE = \"compute\"\n       DATA_LOAD = \"data_load\"\n\n**NodeRole:**\n\n.. code-block:: python\n\n   class NodeRole(Enum):\n       ROLLOUT = \"rollout\"\n       ACTOR = \"actor\"\n       CRITIC = \"critic\"\n       REFERENCE = \"reference\"\n       REWARD = \"reward\"\n       ADVANTAGE = \"advantage\"\n       DYNAMIC_SAMPLING = \"dynamic_sampling\"\n       DEFAULT = \"default\"\n\n**AdvantageEstimator:**\n\n.. code-block:: python\n\n   class AdvantageEstimator(Enum):\n       GRPO = \"grpo\"\n       GAE = \"gae\"\n       CPGD = \"cpgd\"\n       GSPO = \"gspo\"\n\n**WorkflowType:**\n\n.. code-block:: python\n\n   class WorkflowType(Enum):\n       DEFAULT = \"DEFAULT\"\n       DAPO = \"DAPO\"\n       EMBODIED = \"EMBODIED\"\n\nExecution Flow\n--------------\n\nStartup Flow (main_dag.py)\n~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. code-block:: text\n\n   1. Parse configuration (parse_config)\n   2. Load Pipeline (load_pipeline)\n   3. Initialize DataBuffer (init_data_coordinator)\n   4. Initialize MetricWorker\n   5. Task scheduling (TaskScheduler)\n   6. Launch Ray cluster (RayTrainer)\n   7. Create DAGWorker (one per GPU)\n   8. Execute training (DAGWorker.execute_task_graph)\n\nDAGWorker Execution Flow\n~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. code-block:: text\n\n   1. Initialize Workers (Actor/Critic/Rollout/Reference/Reward)\n   2. Initialize DataLoader\n   3. Initialize Validator\n   4. Load Checkpoint (if exists)\n   5. Training loop:\n      - Load data\n      - Execute nodes in topological order\n      - Collect metrics\n      - Save Checkpoint\n      - Validate (if needed)\n\nNode Execution Flow\n~~~~~~~~~~~~~~~~~~~\n\n.. code-block:: text\n\n   1. DAGWorker gets node's executable function\n   2. Call function with current batch\n   3. Function processes data, returns NodeOutput\n   4. Update batch, pass to next node\n   5. 
Collect node metrics\n\nKey Concepts\n------------\n\nTaskGraph\n~~~~~~~~~\n\nDirected Acyclic Graph representing training workflow.\n\n**Core Methods:**\n\n- ``add_node()``: Add node\n- ``build_adjacency_lists()``: Build adjacency lists\n- ``validate_graph()``: Validate DAG\n- ``get_execution_order()``: Get topological sort\n\nPipeline\n~~~~~~~~\n\nDeclarative API for building TaskGraph.\n\n**Core Methods:**\n\n- ``add_node()``: Add node (supports chaining)\n- ``build()``: Build and validate TaskGraph\n\nDAGWorker Class\n~~~~~~~~~~~~~~~\n\nExecution unit per GPU.\n\n**Core Methods:**\n\n- ``generate()``: Rollout generation\n- ``compute_reward()``: Compute reward\n- ``compute_advantage()``: Compute advantage\n- ``compute_old_log_prob()``: Old policy log prob\n- ``compute_ref_log_prob()``: Reference model log prob\n- ``compute_value()``: Value function (PPO)\n- ``train_actor()``: Train actor\n- ``train_critic()``: Train critic (PPO)\n\nConfiguration Parameters\n------------------------\n\nMain Configuration Groups\n~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. code-block:: yaml\n\n   algorithm:\n     adv_estimator: grpo  # grpo/gae/cpgd/gspo\n     workflow_type: DEFAULT  # DEFAULT/DAPO/EMBODIED\n\n   data:\n     train_files: /path/to/train.parquet\n     train_batch_size: 512\n     max_prompt_length: 2048\n     max_response_length: 4096\n\n   actor_rollout_ref:\n     model:\n       path: /path/to/model\n     actor:\n       optim:\n         lr: 1e-6\n       ppo_mini_batch_size: 256\n     rollout:\n       name: vllm  # vllm/sglang/hf\n       tensor_model_parallel_size: 2\n       n: 8  # GRPO group size\n\n   trainer:\n     n_gpus_per_node: 8\n     nnodes: 1\n     total_epochs: 30\n     save_freq: 10\n\n   dag:\n     custom_pipeline_fn: null  # Custom Pipeline\n\nExtension Points\n----------------\n\nCustom Pipeline\n~~~~~~~~~~~~~~~\n\nAdd new functions in ``siirl/execution/dag/builtin_pipelines.py``.\n\nCustom Node Functions\n~~~~~~~~~~~~~~~~~~~~~\n\nImplement functions following the signature:\n\n.. code-block:: python\n\n   def my_node(batch, config=None, **kwargs) -> NodeOutput:\n       return NodeOutput(batch=batch, metrics={})\n\nCustom Reward Manager\n~~~~~~~~~~~~~~~~~~~~~\n\nAdd new classes in ``siirl/engine/reward_manager/``.\n\nCustom Environment\n~~~~~~~~~~~~~~~~~~\n\nAdd new environment classes in ``siirl/environment/``.\n"
  },
  {
    "path": "docs/programming_guide/siiRL_code_explained.rst",
    "content": "siiRL's Implementation Explained\n================================\n\nsiiRL is under active development with an extensive roadmap for future enhancements. We strongly encourage community participation in this endeavor. Contributions in any form are highly valued, including but not limited to: filing issues, proposing new features, enhancing documentation, and providing suggestions for improvement.\n\nOverall Implementation\n----------------------\n\nRL training itself has clear workflow characteristics, and DAG is the mainstream tool for describing workflows. Therefore, the source code of siiRL adopts a DAG-based design pattern. In terms of specific implementation, siiRL abstracts the entire RL training task into a TaskGraph composed of multiple Nodes, each of which implements the ``node.run()`` method to support the abstract orchestration of the top-level TaskGraph. The constructed TaskGraph is submitted to a set of DAGWorkers for execution.\n\nIn the context of multi-agent RL training, different DAGWorkers can process different TaskGraphs in parallel, and the data that different TaskGraphs depend on and process may also vary. Therefore, from a structural perspective, siiRL belongs to the MPMD paradigm.\n\nIn terms of user usage, in addition to the configurations related to Data/Trainer/Model/RL Algorithm used by mainstream RL frameworks, siiRL also provides DAG config, which supports users to customize workflows. The system will parse the DAG configuration when the training starts and correspondingly construct a TaskGraph instance.\n\nComplex task workflow poses higher requirements for resource scheduling. To achieve fine-grained allocation of GPUs, siiRL implements a set of TaskScheduler, which is responsible for making globally optimal scheduling decisions, such as: how much computing resources to allocate to each TaskGraph, and specifically which GPU devices on which servers to use. Finally, the allocation plan generated by TaskScheduler is handed over to the underlying Ray framework for specific execution, making full use of Ray's distributed computing capabilities.\n\n.. figure:: ../../asset/code_explained/siirl_arch.png\n   :width: 60%\n   :align: center\n   :alt: Overall Architecture of siiRL's Code Implementation\n\n   Figure 1: Overall Architecture of siiRL\n\nWe will first provide an overview diagram of the siiRL source code implementation, and then, in the following text, we will introduce each part of the diagram in detail according to the actual execution process.\n\n.. figure:: ../../asset/code_explained/overview_diagram.png\n   :width: 100%\n   :align: center\n   :alt: Diagram of Source Code Implementation\n\n   Figure 2: Diagram of Source Code Implementation\n\nEnvironment Abstraction\n-----------------------\n\nDuring initial RL stage of LLMs, the environment typically refers to the datasets used in post-training. siiRL abstracts the concept of environment to uniformly support RL tasks in different application areas, such as MCP calls and SandBox Server in agentic training scenarios, as well as simulators in the embodied AI domain, or real physical environments for agent interaction.\n\nSimilar to OpenAI Gym, siiRL defines two core asynchronous methods:\n\n- ``reset()``: Resets the environment to its initial state and returns the initial observation. 
This function marks the start of a new episode.\n- ``step(actions)``: Receives actions from one or multiple agents, executes these actions, updates the environment state, and returns a tuple containing (observation at the next time step, reward, information). This is the main loop for agent-environment interaction.\n\nTaking the MathEnv of mathematical tasks as an example, the environment natively supports multiple agents. The step function receives a complex number of actions, and the returned observations are also an array prepared for each agent.\n\n.. code-block:: python\n\n   class MathEnv(BaseEnvironment):\n       async def reset(self, dp_rank: int, ddp_world_size: int, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None):\n           # ...\n           obs = np.array([self.current_state for _ in range(self.n_agents)], dtype=np.object_)\n           self.step_count = 0\n           return obs\n       \n       async def step(self, actions):\n           # ...\n           return next_obs, rewards, infos\n\nControl Flow: Pipeline\n----------------------\n\nThe main pipeline of the siiRL control flow is shown in the figure below. First, load the configuration of the interactive environment, then sequentially complete the initialization of DataBuffer, the loading and parsing of DAG configuration, and the construction of TaskGraph. After the TaskGraph is constructed, the TaskScheduler schedules (makes decisions on) tasks, determining how many GPUs to allocate to each task and calculating the specific allocation topology. Then, use Ray to construct a distributed process group and initialize RayTrainer. Finally, initialize DAGWorker (Ray's Actor) and start the training task.\n\n.. figure:: ../../asset/code_explained/pipeline.png\n   :width: 40%\n   :align: center\n   :alt: Pipeline of Control Flow\n\n   Figure 3: Pipeline of Control Flow\n\nDataLoader and DataBuffer\n-------------------------\n\nDataLoader is a wrapper for torch's StatefulDataLoader, which, in combination with the custom PartitionedRLHFDataset, is responsible for tasks such as loading, preprocessing, and batching of training data. Different from other RL open-source frameworks, DataLoader in siiRL is also abstracted as a Node (DataLoaderNode) and embedded into the TaskGraph for execution. Under normal cluster scale and RL tasks, siiRL launches a data_loader process for each GPU rank, which is responsible for loading the data shard corresponding to the DAGWorker on the current rank.\n\n.. code-block:: python\n\n   class DataLoaderNode(Node):\n       \"\"\"\n       Represents a data loader node in the DAG.\n       This version uses the PartitionedRLHFDataset for efficient, memory-safe\n       distributed data loading. Each rank only loads and processes its own data slice.\n       \"\"\"\n       def run(self, epoch: Optional[int] = None, is_validation_step: bool = False, **kwargs: Any) -> Any:\n           \"\"\"\n           Executes the data loading process for a given step or validation.\n           \"\"\"\n           try:\n               # for validation\n               if is_validation_step:\n                   try:\n                       batch = next(self._current_val_iter)\n               # for training\n               else:  \n                   try:\n                       batch = next(self._current_train_iter)\n               return batch\n\nDataBuffer is essentially a distributed KV Store, maintained by an independent Ray Actor process. Typically, DataLoader is per-gpu, while DataBuffer is per-server. 
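\n\nAs a rough illustration of this design, the sketch below shows a minimal key-value buffer hosted in a dedicated Ray actor. It is only meant to convey the idea: the class and method names are invented for this example and do not reflect siiRL's actual DataBuffer API.\n\n.. code-block:: python\n\n   import ray\n\n   @ray.remote\n   class SimpleKVBuffer:\n       \"\"\"Illustrative per-server KV store; not siiRL's actual DataBuffer.\"\"\"\n\n       def __init__(self):\n           self._store = {}\n\n       def put(self, key, value):\n           # An upstream node writes its output under the downstream node's key.\n           self._store[key] = value\n\n       def get(self, key):\n           # The downstream node fetches the data it depends on.\n           return self._store.pop(key, None)\n\n   ray.init()\n   buffer = SimpleKVBuffer.remote()\n   ray.get(buffer.put.remote(\"actor_train\", {\"advantages\": [0.1, -0.2]}))\n   print(ray.get(buffer.get.remote(\"actor_train\")))\n\nAs described later in this document, siiRL keys such entries by the node_id of the downstream node, with the output of the upstream node's run method as the value.\n\n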
In static batching mode, siiRL checks the load balance when creating DataBuffer, as shown in the figure below. For example, if the training batch size is 128, it needs to be divisible by the number of servers to ensure that a global batch can be evenly distributed among servers. Similarly, the batch size allocated to a server, after being replicated ``n`` times (the group size in GRPO, or ``n = 1`` if it is PPO), also needs to be divisible by 8 to ensure that it can be evenly distributed among GPUs on the same server.\n\n.. figure:: ../../asset/code_explained/data_loader.png\n   :width: 95%\n   :align: center\n   :alt: Diagram of Source Code Implementation\n\n   Figure 4: DataLoader, DataBuffer and Load Balance\n\nTaskGraph Scheduling\n--------------------\n\nThe core of TaskGraph is a dictionary composed of Nodes, and TaskGraph uses adjacency lists and reverse adjacency lists to represent the connection relationships between these Nodes. Among them, the reverse adjacency list is mainly used for dependency checking, such as Actor's training depending on rollout's generation. Meanwhile, TaskGraph provides a series of graph operation methods, such as adding, deleting, modifying, and querying nodes, DAG verification, copying, and displaying the graph, to implement the management of TaskGraph.\n\n.. code-block:: python\n\n   class Node:\n       \"\"\"\n       Represents a node (task unit) in the DAG.\n       \"\"\"\n       \n   class TaskGraph:\n       \"\"\"\n       Represents a Directed Acyclic Graph (DAG) of tasks, \n       composed of multiple Node objects and their dependencies.\n       \"\"\"\n\n       def __init__(self, graph_id: str):\n           \"\"\"\n           Initialize a task graph.\n           Parameters:\n               graph_id (str): The unique identifier of the graph.\n           \"\"\"\n           self.graph_id: str = graph_id\n           self.nodes: Dict[str, Node] = {} \n           self.adj: Dict[str, List[str]] = {}\n           self.rev_adj: Dict[str, List[str]] = {}\n\nThe scheduling of TaskGraph includes four key steps:\n\n1. **TaskGraph Splitting**: When a user-defined workflow contains parallel paths—as seen in multi-agent training where agents use both shared and specific Nodes—siiRL splits the original TaskGraph into multiple subgraphs for sequential execution. While this approach may not be the most efficient, it significantly simplifies resource scheduling.\n\n2. **SubGraph Sorting**: To allocate resources reasonably, siiRL sorts all SubGraphs. The sorting is mainly based on two points. First, the size of the SubGraph, where this size refers to the parameter scale of the model to be trained on the current SubGraph (7B, 32B, 671B, etc.), with priority given to resource allocation for SubGraphs with larger parameter scales. Second, the number of Nodes on the SubGraph; the more Nodes, i.e., the \"longer the chain\" of the SubGraph, the earlier it is allocated.\n\n3. **GPU Quota Allocation**: Based on the sorting results from Step 2, allocate the number of GPUs to each SubGraph. There are two allocation strategies: even and param_aware. In the even mode, the total number of GPUs is evenly distributed among SubGraphs as much as possible; in the param_aware mode, on the premise that each subgraph is allocated at least one GPU, subgraphs with larger sizes are allocated more GPUs as much as possible.\n\n4. **GPU Topology Allocation**: With the allocation of the number of GPUs in Step 3, this step performs topology allocation. 
Suppose there are three SubGraphs, denoted sg1, sg2, and sg3, the training cluster consists of 2 machines with 16 GPUs in total, and the per-SubGraph GPU counts from Step 3 are (6, 5, 5). This step then determines specifically which 6 GPUs are allocated to sg1, which 5 to sg2, and which 5 to sg3. siiRL makes these decisions through a scoring mechanism:\n\n   ``(cohesion_score(+), node_load_score(-), rank_preference_score(-))``\n\n   where ``cohesion_score`` is the cohesion score: place a SubGraph within the same server as much as possible to reduce communication; ``node_load_score`` is the load penalty: balance placement among servers as much as possible; and ``rank_preference_score`` is the rank preference: place tasks on GPUs with smaller rank numbers as much as possible to make the scheduling behavior more predictable.\n\n.. figure:: ../../asset/code_explained/taskgraph_sched.png\n   :width: 95%\n   :align: center\n   :alt: TaskGraph Scheduling\n\n   Figure 5: TaskGraph Scheduling\n\nBuild the Distributed Process Group\n-----------------------------------\n\nAfter task scheduling is completed, the Ray distributed process groups can be constructed. According to the topology determined by the scheduling above, an affiliated process group is constructed for each Node of the TaskGraph.\n\nFor example, for actor training (described as ``NodeRole=Actor, NodeType=Train`` in siiRL), if the assigned ranks are ``[0, 1, 2, 3, 4, 5]``, a Python tuple is used as the key and a unique string as the value for naming: ``(0,1,2,3,4,5): \"process_group_1\"``.\n\n.. figure:: ../../asset/code_explained/dist_pg.png\n   :width: 95%\n   :align: center\n   :alt: Distributed Process Group\n\n   Figure 6: Distributed Process Group\n\nRay Trainer\n-----------\n\nAfter constructing the process groups, initialize RayTrainer. This part is similar to the practice of other mainstream frameworks, with the core being the instantiation of Ray's resource pool management, i.e., the resource_manager. Finally, collectively validate the configurations of all Nodes (Actor/Rollout/Reward, etc.).\n\n.. figure:: ../../asset/code_explained/ray_trainer.png\n   :width: 95%\n   :align: center\n   :alt: Ray Trainer\n\n   Figure 7: Ray Trainer\n\nDAGWorker\n---------\n\nThrough a series of abstractions around the DAG and TaskGraph, siiRL encapsulates and hides the training job flow beneath the control flow. The call logic related to the training backend, inference backend, sharding manager, etc., which is directly visible in the control flow of veRL, is all encapsulated into the DAGWorker in siiRL and is almost invisible in the control flow. In terms of programming model, this hiding provides a higher level of abstraction, offering more convenient modular reuse and more flexible extensibility than other mainstream frameworks, but it may also increase the complexity of bug localization.\n\nIn terms of source code implementation, DAGWorker uses mixin classes for modularization. There are five core mixin classes, responsible for initialization, pipeline execution, execution of specific Nodes, training validation, and utility functions, respectively, as shown below.\n\n.. figure:: ../../asset/code_explained/dag_worker.png\n   :width: 70%\n   :align: left\n   :alt: DAG Worker\n\nWhen initializing a DAGWorker, first call the resource_manager (the one created during RayTrainer initialization) to create the ResourcePool, then create the RayActorManager to manage the lifecycle of all distributed DAGWorkers. 
Finally, call the method defined in the InitializationMixin mixin class to complete the initialization of DAGWorker.\n\n.. figure:: ../../asset/code_explained/dag_init.png\n   :width: 80%\n   :align: center\n   :alt: Initialization of DAG Worker\n\n   Figure 8: Initialization of DAG Worker\n\nWhen setting up the communication group, siiRL adopts the following strategy: if the total number of ranks is less than 256, it uses the pure NCCL backend; otherwise, it uses the GLOO+NCCL hybrid backend. In the hybrid backend mode, GLOO is mainly used for aggregated communication of data such as logs and metrics.\n\nTraining Initiation\n-------------------\n\nThe main pipeline initiates training in the final step. Here, it primarily calls the ``execute_task_graph`` method in the ExecutionMixin mixin class. This method encapsulates the outer loop of epochs and the inner loop of batches within each epoch (i.e., a training step).\n\n.. figure:: ../../asset/code_explained/train_init.png\n   :width: 70%\n   :align: center\n   :alt: Training Job Initialization\n\n   Figure 9: Training Job Initialization\n\nEach training step is no longer \"concrete and expanded\", as in mainstream frameworks such as veRL, but rather \"abstract and cyclic\": traverse all Nodes in the Graph, for each Node, execute the run method, and write the resulting data to the DataBuffer, where the key is the node_id of the next node and the value is the output of the run method.\n\n.. figure:: ../../asset/code_explained/data_buffer_loop.png\n   :width: 70%\n   :align: center\n   :alt: Loop of TaskGraph Computation based on DataBuffer\n\n   Figure 10: Loop of TaskGraph Computation based on DataBuffer\n\n"
  },
  {
    "path": "docs/programming_guide/siirl_architecture_guide.rst",
    "content": "=======================================\nsiiRL Complete Architecture Guide\n=======================================\n\n.. note::\n   **Target Audience**: This document assumes no prior knowledge of siiRL, but expects basic familiarity with Python, PyTorch, and reinforcement learning concepts.\n   We will systematically explain siiRL's design philosophy, architecture implementation, and core algorithms from the ground up.\n\nTable of Contents\n=================\n\n- :ref:`sec1_overview`\n- :ref:`sec2_design_philosophy`\n- :ref:`sec3_main_entry`\n- :ref:`sec4_dag_planner`\n- :ref:`sec5_dag_worker`\n- :ref:`sec6_data_coordinator`\n- :ref:`sec7_engine`\n- :ref:`sec8_core_algorithms`\n- :ref:`sec9_execution_flow`\n- :ref:`sec10_configuration`\n- :ref:`sec11_extension_guide`\n\n----\n\n.. _sec1_overview:\n\n1. siiRL Architecture Overview\n==============================\n\n1.1 What is siiRL?\n------------------\n\n**siiRL** (Shanghai Innovation Institute RL Framework) is a novel **fully distributed reinforcement learning framework** designed to break the scaling barriers in LLM post-training. By eliminating the centralized controller common in other frameworks, siiRL achieves:\n\n- **Near-Linear Scalability**: The multi-controller paradigm eliminates central bottlenecks by distributing control logic and data management across all workers\n- **SOTA Throughput**: Fully distributed dataflow architecture minimizes communication and I/O overhead\n- **Flexible DAG-Defined Pipeline**: Decouples algorithmic logic from physical hardware, enabling rapid experimentation\n\n1.2 System Architecture and Data Flow\n-------------------------------------\n\n**System Architecture Diagram**:\n\n.. figure:: https://github.com/sii-research/siiRL/raw/main/asset/overview.png\n   :width: 100%\n   :alt: siiRL Architecture Overview\n   :align: center\n   \n   **Figure 1.1**: siiRL System Architecture showing the three core components: DAG Planner, DAG Workers, and Data Coordinator\n\n**Complete Training Step Sequence Diagram**:\n\nThe following sequence diagram shows the complete data flow for a single GRPO training step:\n\n::\n\n      User          MainRunner       DAGWorker      DataCoordinator     Engine\n     (YAML)         (Planner)       (per GPU)        (Singleton)       Workers\n        |               |               |                 |               |\n   ============================================================================\n   | INITIALIZATION PHASE                                                     |\n   ============================================================================\n        |               |               |                 |               |\n        | 1. Config     |               |                 |               |\n        |-------------->|               |                 |               |\n        |               |               |                 |               |\n        |               | 2. load_pipeline() + TaskScheduler.schedule()   |\n        |               |------------------------------------------------>|\n        |               |               |                 |               |\n        |               | 3. Create DAGWorkers (one per GPU)              |\n        |               |-------------->|                 |               |\n        |               |               |                 |               |\n        |               |               | 4. 
init_graph() |               |\n        |               |               |    Load models  |               |\n        |               |               |-------------------------------->|\n        |               |               |                 |               |\n   ============================================================================\n   | TRAINING LOOP (per step)                                                 |\n   ============================================================================\n        |               |               |                 |               |\n        |               |               | 5. DataLoader   |               |\n        |               |               |    .run()       |               |\n        |               |               |<----------------|               |\n        |               |               | batch (prompts) |               |\n        |               |               |                 |               |\n        |               |               | 6. Node: rollout_actor          |\n        |               |               |-------------------------------->|\n        |               |               |     Rollout.generate_sequences()|\n        |               |               |<--------------------------------|\n        |               |               | batch + responses               |\n        |               |               |                 |               |\n        |               |               | 7. Node: function_reward        |\n        |               |               |    compute_reward()             |\n        |               |               |---------------->|               |\n        |               |               | batch + scores  |               |\n        |               |               |                 |               |\n        |               |               | 8. Node: calculate_advantages   |\n        |               |               |    compute_advantage()          |\n        |               |               |    (GRPO group normalization)   |\n        |               |               |                 |               |\n        |               |               | 9. put_data_to_buffers()        |\n        |               |               |    (if DP size changes)         |\n        |               |               |---------------->|               |\n        |               |               |                 | ray.put()     |\n        |               |               |                 |               |\n        |               |               | 10. get_data_from_buffers()     |\n        |               |               |<----------------|               |\n        |               |               | redistributed batch             |\n        |               |               |                 |               |\n        |               |               | 11. Node: actor_old_log_prob    |\n        |               |               |-------------------------------->|\n        |               |               |     Actor.compute_log_prob()    |\n        |               |               |<--------------------------------|\n        |               |               | batch + old_log_probs           |\n        |               |               |                 |               |\n        |               |               | 12. 
Node: reference_log_prob    |\n        |               |               |-------------------------------->|\n        |               |               |   Reference.compute_ref_log_prob|\n        |               |               |<--------------------------------|\n        |               |               | batch + ref_log_probs           |\n        |               |               |                 |               |\n        |               |               | 13. Node: actor_train           |\n        |               |               |-------------------------------->|\n        |               |               |     Actor.update_actor()        |\n        |               |               |     - Forward pass              |\n        |               |               |     - Compute policy loss       |\n        |               |               |     - Backward pass             |\n        |               |               |     - Optimizer step            |\n        |               |               |<--------------------------------|\n        |               |               | metrics                         |\n        |               |               |                 |               |\n        |               |               | 14. sync_weights_actor_to_rollout\n        |               |               |-------------------------------->|\n        |               |               |     ShardingManager.sync()      |\n        |               |               |                 |               |\n        |               |               | 15. Log metrics + checkpoint    |\n        |               |               |                 |               |\n   ============================================================================\n   | REPEAT for next training step                                            |\n   ============================================================================\n\n**Data Flow Summary**:\n\n::\n\n                              GRPO Single Step Data Flow\n   ==============================================================================\n   \n   DataLoader\n       |\n       | batch: {prompts, attention_mask, index}\n       v\n   +---------------------+\n   | rollout_actor       | DAGWorker.generate()\n   | (MODEL_INFERENCE)   | -> Rollout.generate_sequences()\n   +----------+----------+\n              | + {responses, response_ids, response_mask}\n              v\n   +---------------------+\n   | function_reward     | DAGWorker.compute_reward()\n   | (COMPUTE)           | -> RewardManager.compute_reward()\n   +----------+----------+\n              | + {token_level_scores, token_level_rewards}\n              v\n   +---------------------+\n   | calculate_advantages| DAGWorker.compute_advantage()\n   | (COMPUTE)           | -> compute_grpo_outcome_advantage()\n   +----------+----------+ Group by prompt -> Normalize (score - mean)/std\n              | + {advantages}\n              v\n   +---------------------+\n   | actor_old_log_prob  | DAGWorker.compute_old_log_prob()\n   | (MODEL_TRAIN)       | -> Actor.compute_log_prob()\n   | only_forward=True   |\n   +----------+----------+\n              | + {old_log_probs}\n              v\n   +---------------------+\n   | reference_log_prob  | DAGWorker.compute_ref_log_prob()\n   | (MODEL_TRAIN)       | -> Reference.compute_ref_log_prob()\n   +----------+----------+\n              | + {ref_log_prob}\n              v\n   +---------------------+\n   | actor_train         | DAGWorker.train_actor()\n   | (MODEL_TRAIN)       | -> Actor.update_actor()\n   
+----------+----------+ policy_loss = -advantages * clip(ratio)\n              |\n              | metrics: {loss, clipfrac, kl, lr, ...}\n              v\n   +---------------------+\n   | sync_weights        | ShardingManager.sync_weights_actor_to_rollout()\n   +---------------------+                                            \n\n1.3 Core Component Responsibilities\n-----------------------------------\n\n.. list-table:: siiRL Core Components\n   :header-rows: 1\n   :widths: 20 20 60\n\n   * - Component\n     - Process/Actor\n     - Core Responsibilities\n   * - **DAG Planner**\n     - MainRunner Actor\n     - Parse user-defined DAG workflows, generate execution plans, assign tasks to workers\n   * - **DAG Worker**\n     - One Actor per GPU\n     - Core execution unit responsible for model initialization, task execution, data flow management\n   * - **Data Coordinator**\n     - Global Singleton Actor\n     - Manage distributed data lifecycle including data loading and intermediate data redistribution\n   * - **TaskScheduler**\n     - Inside MainRunner\n     - Split and assign TaskGraph to each DAG Worker\n   * - **ProcessGroupManager**\n     - Inside MainRunner\n     - Manage creation and configuration of distributed communication groups (TP/PP/DP)\n   * - **MetricWorker**\n     - Standalone Actor\n     - Distributed metrics collection and aggregation\n\n1.4 Why is siiRL Different?\n---------------------------\n\n**Problems with Traditional Frameworks**:\n\n1. **Single-Controller Bottleneck**: All data flows through a single node, causing I/O and communication overhead\n2. **Rigid Algorithm Pipelines**: Modifying workflows requires deep modifications to framework source code\n\n**siiRL's Solutions**:\n\n.. list-table:: siiRL Design Advantages\n   :header-rows: 1\n   :widths: 25 35 40\n\n   * - Traditional Frameworks\n     - siiRL DistFlow\n     - Advantage\n   * - Centralized Controller\n     - Multi-Controller Paradigm\n     - Eliminates single-point bottleneck, near-linear scaling\n   * - Hard-coded Workflows\n     - DAG-Defined Pipeline\n     - Declarative configuration, no code modification needed\n   * - Centralized Data Management\n     - Distributed Data Coordinator\n     - Avoids OOM, parallelizes data loading\n\n----\n\n.. _sec2_design_philosophy:\n\n2. DistFlow Design Philosophy\n=============================\n\n2.1 Fully Distributed Architecture\n----------------------------------\n\nThe core idea of DistFlow is **\"no central coordinator\"**. 
Each DAG Worker is an independent execution unit with its own:\n\n- Data loader (partitioned dataset)\n- Model instances (Actor/Critic/Rollout/Reference/Reward)\n- Task execution graph (subgraph of TaskGraph)\n- Local data cache\n\n2.2 Three-Layer Architecture Design\n-----------------------------------\n\n::\n\n   ┌─────────────────────────────────────────────────────────────────┐\n   │                     User Configuration Layer (YAML/Python)      │\n   │   - workflow_grpo.yaml: Define algorithm DAG                    │\n   │   - config.yaml: Model, data, training parameters               │\n   └─────────────────────────────────────────────────────────────────┘\n                                    │\n                                    ▼\n   ┌─────────────────────────────────────────────────────────────────┐\n   │                     Execution Scheduling Layer (DAG Planner)     │\n   │   - TaskScheduler: Task assignment                              │\n   │   - ProcessGroupManager: Communication group management          │\n   │   - GraphUpdater: Configuration injection                       │\n   └─────────────────────────────────────────────────────────────────┘\n                                    │\n                                    ▼\n   ┌─────────────────────────────────────────────────────────────────┐\n   │                     Distributed Execution Layer (DAG Workers)    │\n   │   ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐       │\n   │   │Worker 0  │  │Worker 1  │  │Worker 2  │  │Worker N  │       │\n   │   │ (GPU 0)  │  │ (GPU 1)  │  │ (GPU 2)  │  │ (GPU N)  │       │\n   │   └──────────┘  └──────────┘  └──────────┘  └──────────┘       │\n   └─────────────────────────────────────────────────────────────────┘\n                                    │\n                                    ▼\n   ┌─────────────────────────────────────────────────────────────────┐\n   │                     Data Coordination Layer (Data Coordinator)   │\n   │   - Distributed DataLoader: Partitioned data loading            │\n   │   - Distributed DataBuffer: Intermediate data redistribution    │\n   └─────────────────────────────────────────────────────────────────┘\n\n2.3 Core Design Principles\n--------------------------\n\n.. list-table:: DistFlow Design Principles\n   :header-rows: 1\n   :widths: 25 75\n\n   * - Principle\n     - Description\n   * - **Worker Autonomy**\n     - Each DAG Worker is a fully independent execution unit, not dependent on central coordination\n   * - **Data Locality**\n     - Data is processed locally as much as possible, reducing cross-node transfers\n   * - **Declarative Workflows**\n     - Algorithm logic is declared via DAG, decoupled from execution engine\n   * - **Unified Sample Protocol**\n     - All intermediate data uses Sample/SampleInfo protocol, supporting flexible routing\n   * - **Late Binding**\n     - Configuration is injected into nodes at runtime, supporting dynamic adjustment\n\n----\n\n.. _sec3_main_entry:\n\n3. Program Entry and Startup Flow\n=================================\n\n3.1 main_dag.py Explained\n-------------------------\n\n``main_dag.py`` is the entry point of siiRL, but unlike traditional frameworks, its role is a **launcher** rather than an executor.\n\n.. code-block:: python\n   :caption: siirl/main_dag.py Core Structure\n\n   def main() -> None:\n       \"\"\"Main entry: Initialize Ray cluster, parse config, start MainRunner\"\"\"\n       \n       # 1. 
Initialize Ray cluster\n       if not ray.is_initialized():\n           ray.init(runtime_env={\"env_vars\": RAY_RUNTIME_ENV_VARS})\n       \n       # 2. Parse configuration\n       siirl_args = parse_config()\n       \n       # 3. Start main orchestration Actor\n       runner = MainRunner.remote()\n       ray.get(runner.run.remote(siirl_args))\n\n3.2 MainRunner Actor\n--------------------\n\n``MainRunner`` is the \"brain\" of the system, responsible for orchestrating the entire training workflow:\n\n.. code-block:: python\n   :caption: MainRunner.run() Core Flow\n\n   @ray.remote(num_cpus=MAIN_RUNNER_CPU_RESERVATION)\n   class MainRunner:\n       def run(self, siirl_args: SiiRLArguments) -> None:\n           # 1. Initialize DataCoordinator\n           data_coordinator_handle = init_data_coordinator(\n               num_buffers=siirl_args.trainer.nnodes,\n               ppo_mini_batch_size=siirl_args.actor_rollout_ref.actor.ppo_mini_batch_size,\n               world_size=siirl_args.trainer.nnodes * siirl_args.trainer.n_gpus_per_node\n           )\n           \n           # 2. Load and configure workflow DAG\n           workflow_taskgraph = load_pipeline(siirl_args)\n           update_task_graph_node_configs(workflow_taskgraph, siirl_args)\n           \n           # 3. Schedule tasks to each worker\n           task_scheduler = TaskScheduler(siirl_args.trainer.nnodes, \n                                          siirl_args.trainer.n_gpus_per_node)\n           rank_taskgraph_mapping = task_scheduler.schedule_and_assign_tasks([workflow_taskgraph])\n           \n           # 4. Create process groups\n           process_group_manager = ProcessGroupManager(total_workers, rank_taskgraph_mapping)\n           \n           # 5. Create metric worker\n           metric_worker_handle = MetricWorker.remote()\n           \n           # 6. Initialize and start DAG Workers\n           trainer = RayTrainer(config=siirl_args, ...)\n           trainer.init_workers()\n           trainer.start_workers()\n\n3.3 Startup Flow Sequence Diagram\n---------------------------------\n\n::\n\n   main()\n      │\n      ├── ray.init()                          ← Initialize Ray cluster\n      │\n      ├── parse_config()                      ← Parse YAML configuration\n      │\n      └── MainRunner.run()\n              │\n              ├── init_data_coordinator()     ← Create global DataCoordinator\n              │\n              ├── load_pipeline()             ← Load DAG definition\n              │       │\n              │       └── grpo_pipeline()     ← Return TaskGraph\n              │\n              ├── TaskScheduler.schedule()    ← Assign tasks to each rank\n              │\n              ├── ProcessGroupManager()       ← Create communication group specs\n              │\n              ├── RayTrainer.init_workers()   ← Create DAG Worker Actors\n              │       │\n              │       └── DAGWorker.__init__() × N_workers\n              │\n              └── RayTrainer.start_workers()  ← Start training loop\n                      │\n                      └── DAGWorker.execute_task_graph() × N_workers\n\n----\n\n.. _sec4_dag_planner:\n\n4. 
DAG Planner Deep Dive\n========================\n\nThe DAG Planner is siiRL's \"scheduling brain\", responsible for converting user-defined high-level workflows into executable distributed tasks.\n\n**Pipeline Architecture Overview**:\n\nThe following diagram shows how the core data structures relate to each other and how a Pipeline is built and executed:\n\n::\n\n                           Pipeline Data Structure Relationships\n   ==============================================================================\n   \n                                 +------------------+\n                                 |    Pipeline      |\n                                 |    (Builder)     |\n                                 +------------------+\n                                 | - pipeline_id    |\n                                 | - description    |\n                                 | - _nodes: Dict   |\n                                 +--------+---------+\n                                          |\n                                          | .build()\n                                          v\n                                 +------------------+\n                                 |   TaskGraph      |\n                                 |     (DAG)        |\n                                 +------------------+\n                                 | - graph_id       |\n                                 | - nodes: Dict    |\n                                 | - adj: Dict      |\n                                 | - rev_adj: Dict  |\n                                 +--------+---------+\n                                          |\n                                          | contains multiple\n                                          v\n         +----------------+    +----------------+    +----------------+\n         |     Node       |    |     Node       |    |     Node       |  ...\n         +----------------+    +----------------+    +----------------+\n         | - node_id      |    | - node_id      |    | - node_id      |\n         | - node_type    |    | - node_type    |    | - node_type    |\n         | - node_role    |    | - node_role    |    | - node_role    |\n         | - dependencies |    | - dependencies |    | - dependencies |\n         | - executable   |    | - executable   |    | - executable   |\n         | - config       |    | - config       |    | - config       |\n         +----------------+    +----------------+    +----------------+\n   \n   ==============================================================================\n   \n   NodeType (from node.py)             NodeRole (from node.py)\n   +------------------------+          +------------------------+\n   | COMPUTE                |          | DEFAULT                |\n   | DATA_LOAD              |          | ACTOR                  |\n   | ENV_INTERACT           |          | ADVANTAGE              |\n   | MODEL_INFERENCE        |          | CRITIC                 |\n   | MODEL_TRAIN            |          | ROLLOUT                |\n   | PUT_TO_BUFFER          |          | REFERENCE              |\n   | GET_FROM_BUFFER        |          | REWARD                 |\n   | BARRIER_SYNC           |          | DYNAMIC_SAMPLING       |\n   | CUSTOM                 |          +------------------------+\n   +------------------------+\n\n**Pipeline Building Flow**:\n\n::\n\n                            How Pipeline is Built and Executed\n   ================================================================================\n   \n   Step 1: User 
Defines Pipeline (Python Code)\n   --------------------------------------------\n   \n       pipeline = Pipeline(\"grpo_training_pipeline\")\n       \n       pipeline.add_node(\"rollout_actor\", func=\"...:DAGWorker.generate\", deps=[])\n              .add_node(\"function_reward\", func=\"...:DAGWorker.compute_reward\", ...)\n              .add_node(\"calculate_advantages\", func=\"...:DAGWorker.compute_advantage\", ...)\n              .add_node(\"actor_old_log_prob\", func=\"...:DAGWorker.compute_old_log_prob\", ...)\n              .add_node(\"reference_log_prob\", func=\"...:DAGWorker.compute_ref_log_prob\", ...)\n              .add_node(\"actor_train\", func=\"...:DAGWorker.train_actor\", ...)\n   \n                                            |\n                                            | pipeline.build()\n                                            v\n   \n   Step 2: Build TaskGraph (Validation + Adjacency Lists)\n   ------------------------------------------------------\n   \n       TaskGraph                          Adjacency Lists (adj)\n       +--------------------+             +------------------------------------------+\n       | graph_id: \"grpo..\" |             | rollout_actor      -> [function_reward]  |\n       |                    |             | function_reward    -> [calculate_adv.]   |\n       | nodes: {           |             | calculate_adv.     -> [actor_old_log]    |\n       |   \"rollout_actor\", |             | actor_old_log      -> [reference_log]    |\n       |   \"function_reward\"|             | reference_log      -> [actor_train]      |\n       |   \"calculate_adv.\",|             | actor_train        -> []                 |\n       |   ...              |             +------------------------------------------+\n       | }                  |\n       +--------------------+\n                                            |\n                                            | TaskScheduler.schedule()\n                                            v\n   \n   Step 3: TaskScheduler Assigns to Workers\n   ----------------------------------------\n   \n       +------------------------------------------------------------------------+\n       |  TaskScheduler                                                         |\n       |                                                                        |\n       |  Input: TaskGraph + num_workers                                        |\n       |                                                                        |\n       |  1. discover_and_split_parallel_paths(graph) -> Split parallel branches|\n       |  2. Apportion workers to subgraphs (param_aware / even)                |\n       |  3. Assign each worker a TaskGraph copy                                |\n       |                                                                        |\n       |  Output: Dict[rank, TaskGraph] (rank_taskgraph_mapping)                |\n       +------------------------------------------------------------------------+\n   \n                          +-------------------------------------------+\n                          |           rank_taskgraph_mapping          |\n                          +-------------------------------------------+\n                          |  rank 0  ->  TaskGraph (copy)             |\n                          |  rank 1  ->  TaskGraph (copy)             |\n                          |  rank 2  ->  TaskGraph (copy)             |\n                          |  ...     ->  ...                          
|\n                          |  rank N  ->  TaskGraph (copy)             |\n                          +-------------------------------------------+\n                                            |\n                                            | DAGWorker receives TaskGraph\n                                            v\n   \n   Step 4: DAGWorker Executes TaskGraph\n   ------------------------------------\n   \n       +------------------------------------------------------------------------+\n       |  DAGWorker.execute_task_graph()                                        |\n       |                                                                        |\n       |  for each training step:                                               |\n       |      1. batch = DataLoader.run()                                       |\n       |      2. entry_nodes = taskgraph.get_entry_nodes()  # [rollout_actor]   |\n       |      3. node_queue = entry_nodes                                       |\n       |                                                                        |\n       |      while node_queue:                                                 |\n       |          cur_node = node_queue.pop(0)                                  |\n       |                                                                        |\n       |          # Execute node's function                                     |\n       |          output = cur_node.run(batch=batch, _dag_worker_instance=self) |\n       |                                                                        |\n       |          # Resolves executable_ref to actual function:                 |\n       |          # \"siirl.dag_worker.dagworker:DAGWorker.generate\"             |\n       |          #  -> DAGWorker.generate(self, batch, ...)                    
|\n       |                                                                        |\n       |          # Get downstream nodes and add to queue                       |\n       |          next_nodes = taskgraph.get_downstream_nodes(cur_node.node_id) |\n       |          node_queue.extend(next_nodes)                                 |\n       |                                                                        |\n       |          # If DP size changes between nodes, use DataCoordinator       |\n       |          put_data_to_buffers() / get_data_from_buffers()               |\n       +------------------------------------------------------------------------+\n\n**Execution Order Example (GRPO)**:\n\n::\n\n                            GRPO Pipeline Execution Order\n   ================================================================================\n   \n   Topological Order:\n   \n     +------------------+      +------------------+      +---------------------+\n     |  rollout_actor   |----->| function_reward  |----->|calculate_advantages |\n     |  (Inference)     |      |    (Compute)     |      |      (Compute)      |\n     |                  |      |                  |      |                     |\n     |  NodeRole:       |      |  NodeRole:       |      |  NodeRole:          |\n     |  ROLLOUT         |      |  REWARD          |      |  ADVANTAGE          |\n     +------------------+      +------------------+      +----------+----------+\n                                                                    |\n         +----------------------------------------------------------+\n         |\n         v\n     +---------------------+      +---------------------+      +------------------+\n     | actor_old_log_prob  |----->| reference_log_prob  |----->|   actor_train    |\n     |   (Forward Only)    |      |   (Forward Only)    |      |     (Train)      |\n     |                     |      |                     |      |                  |\n     |  NodeRole: ACTOR    |      |  NodeRole: REFERENCE|      |  NodeRole: ACTOR |\n     |  only_forward=True  |      |                     |      |                  |\n     +---------------------+      +---------------------+      +------------------+\n   \n   Data flows through each node, accumulating fields in the batch:\n   \n     batch: {prompts}\n        |\n        v rollout_actor\n     batch: {prompts, responses, response_ids, response_mask}\n        |\n        v function_reward  \n     batch: {..., token_level_scores, token_level_rewards}\n        |\n        v calculate_advantages\n     batch: {..., advantages}\n        |\n        v actor_old_log_prob\n     batch: {..., old_log_probs}\n        |\n        v reference_log_prob\n     batch: {..., ref_log_prob}\n        |\n        v actor_train\n     metrics: {loss, clipfrac, kl, ...}\n\n4.1 Pipeline API\n----------------\n\nsiiRL provides a clean Pipeline API for users to define training pipelines directly in Python:\n\n.. 
code-block:: python\n   :caption: siirl/execution/dag/pipeline.py\n\n   class Pipeline:\n       \"\"\"Declarative Pipeline Builder\"\"\"\n       \n       def __init__(self, pipeline_id: str, description: str = \"\"):\n           self.pipeline_id = pipeline_id\n           self._nodes: Dict[str, Dict[str, Any]] = {}\n       \n       def add_node(\n           self,\n           node_id: str,\n           func: Union[str, Callable],  # Function path or direct Callable\n           deps: Optional[List[str]] = None,\n           **kwargs\n       ) -> \"Pipeline\":\n           \"\"\"Add node with method chaining support\"\"\"\n           self._nodes[node_id] = {\n               \"func\": func,\n               \"deps\": deps or [],\n               \"kwargs\": kwargs\n           }\n           return self  # Support method chaining\n       \n       def build(self) -> TaskGraph:\n           \"\"\"Build and validate TaskGraph\"\"\"\n           task_graph = TaskGraph(graph_id=self.pipeline_id)\n           # ... create nodes, build adjacency lists, validate DAG\n           return task_graph\n\n4.2 Built-in Pipeline Definitions\n---------------------------------\n\nsiiRL provides four built-in pipeline definitions in ``siirl/execution/dag/builtin_pipelines.py``:\n\n**4.2.1 GRPO Pipeline (grpo_pipeline)**\n\nStandard GRPO (Group Relative Policy Optimization) training workflow:\n\n.. code-block:: python\n   :caption: siirl/execution/dag/builtin_pipelines.py - GRPO Pipeline\n\n   def grpo_pipeline() -> TaskGraph:\n       \"\"\"\n       Standard GRPO (Group Relative Policy Optimization) pipeline.\n\n       Workflow:\n           1. rollout_actor: Generate sequences using the policy model\n           2. function_reward: Compute rewards for generated sequences\n           3. calculate_advantages: Calculate advantage estimates\n           4. actor_old_log_prob: Compute log probabilities with old policy (forward only)\n           5. reference_log_prob: Compute log probabilities with reference model\n           6. 
actor_train: Train the actor model\n       \"\"\"\n       pipeline = Pipeline(\"grpo_training_pipeline\", \"Standard GRPO workflow\")\n\n       pipeline.add_node(\n           \"rollout_actor\",\n           func=\"siirl.dag_worker.dagworker:DAGWorker.generate\",\n           deps=[],\n           node_type=NodeType.MODEL_INFERENCE,\n           node_role=NodeRole.ROLLOUT\n       ).add_node(\n           \"function_reward\",\n           func=\"siirl.dag_worker.dagworker:DAGWorker.compute_reward\",\n           deps=[\"rollout_actor\"],\n           node_type=NodeType.COMPUTE,\n           node_role=NodeRole.REWARD\n       ).add_node(\n           \"calculate_advantages\",\n           func=\"siirl.dag_worker.dagworker:DAGWorker.compute_advantage\",\n           deps=[\"function_reward\"],\n           node_type=NodeType.COMPUTE,\n           node_role=NodeRole.ADVANTAGE\n       ).add_node(\n           \"actor_old_log_prob\",\n           func=\"siirl.dag_worker.dagworker:DAGWorker.compute_old_log_prob\",\n           deps=[\"calculate_advantages\"],\n           node_type=NodeType.MODEL_TRAIN,\n           node_role=NodeRole.ACTOR,\n           only_forward_compute=True\n       ).add_node(\n           \"reference_log_prob\",\n           func=\"siirl.dag_worker.dagworker:DAGWorker.compute_ref_log_prob\",\n           deps=[\"actor_old_log_prob\"],\n           node_type=NodeType.MODEL_TRAIN,\n           node_role=NodeRole.REFERENCE\n       ).add_node(\n           \"actor_train\",\n           func=\"siirl.dag_worker.dagworker:DAGWorker.train_actor\",\n           deps=[\"reference_log_prob\"],\n           node_type=NodeType.MODEL_TRAIN,\n           node_role=NodeRole.ACTOR\n       )\n\n       return pipeline.build()\n\n**4.2.2 PPO Pipeline (ppo_pipeline)**\n\nStandard PPO with Critic model and GAE advantage estimation:\n\n.. code-block:: python\n   :caption: siirl/execution/dag/builtin_pipelines.py - PPO Pipeline\n\n   def ppo_pipeline() -> TaskGraph:\n       \"\"\"\n       Standard PPO (Proximal Policy Optimization) pipeline.\n\n       Workflow:\n           1. rollout_actor: Generate sequences using the policy model\n           2. function_reward: Compute rewards for generated sequences\n           3. compute_value: Compute value function estimates (forward only)\n           4. calculate_advantages: Calculate GAE (Generalized Advantage Estimation)\n           5. actor_old_log_prob: Compute log probabilities with old policy (forward only)\n           6. reference_log_prob: Compute log probabilities with reference model\n           7. actor_train: Train the actor model\n           8. 
critic_train: Train the critic (value) model\n       \"\"\"\n       pipeline = Pipeline(\"ppo_training_pipeline\", \"Standard PPO workflow\")\n\n       pipeline.add_node(\n           \"rollout_actor\",\n           func=\"siirl.dag_worker.dagworker:DAGWorker.generate\",\n           deps=[],\n           node_type=NodeType.MODEL_INFERENCE,\n           node_role=NodeRole.ROLLOUT\n       ).add_node(\n           \"function_reward\",\n           func=\"siirl.dag_worker.dagworker:DAGWorker.compute_reward\",\n           deps=[\"rollout_actor\"],\n           node_type=NodeType.COMPUTE,\n           node_role=NodeRole.REWARD\n       ).add_node(\n           \"compute_value\",\n           func=\"siirl.dag_worker.dagworker:DAGWorker.compute_value\",\n           deps=[\"function_reward\"],\n           node_type=NodeType.MODEL_TRAIN,\n           node_role=NodeRole.CRITIC,\n           only_forward_compute=True\n       ).add_node(\n           \"calculate_advantages\",\n           func=\"siirl.dag_worker.dagworker:DAGWorker.compute_advantage\",\n           deps=[\"compute_value\"],\n           node_type=NodeType.COMPUTE,\n           node_role=NodeRole.ADVANTAGE\n       ).add_node(\n           \"actor_old_log_prob\",\n           func=\"siirl.dag_worker.dagworker:DAGWorker.compute_old_log_prob\",\n           deps=[\"calculate_advantages\"],\n           node_type=NodeType.MODEL_TRAIN,\n           node_role=NodeRole.ACTOR,\n           only_forward_compute=True\n       ).add_node(\n           \"reference_log_prob\",\n           func=\"siirl.dag_worker.dagworker:DAGWorker.compute_ref_log_prob\",\n           deps=[\"actor_old_log_prob\"],\n           node_type=NodeType.MODEL_TRAIN,\n           node_role=NodeRole.REFERENCE\n       ).add_node(\n           \"actor_train\",\n           func=\"siirl.dag_worker.dagworker:DAGWorker.train_actor\",\n           deps=[\"reference_log_prob\"],\n           node_type=NodeType.MODEL_TRAIN,\n           node_role=NodeRole.ACTOR\n       ).add_node(\n           \"critic_train\",\n           func=\"siirl.dag_worker.dagworker:DAGWorker.train_critic\",\n           deps=[\"actor_train\"],\n           node_type=NodeType.MODEL_TRAIN,\n           node_role=NodeRole.CRITIC\n       )\n\n       return pipeline.build()\n\n**4.2.3 DAPO Pipeline (dapo_pipeline)**\n\nDAPO (Data-Augmented Policy Optimization) with dynamic sampling filtering:\n\n.. code-block:: python\n   :caption: siirl/execution/dag/builtin_pipelines.py - DAPO Pipeline\n\n   def dapo_pipeline() -> TaskGraph:\n       \"\"\"\n       DAPO (Data-Augmented Policy Optimization) pipeline.\n\n       DAPO is a variant of GRPO with dynamic sampling filtering based on metric variance.\n       The key difference is that after computing rewards, we filter out trajectory groups\n       with zero variance (all correct or all incorrect) as they provide no learning signal.\n\n       Workflow:\n           1. rollout_actor: Generate sequences using the policy model\n           2. function_reward: Compute rewards for generated sequences\n           3. dynamic_sampling: DAPO-specific filtering based on metric variance\n           4. calculate_advantages: Calculate advantage estimates\n           5. actor_old_log_prob: Compute log probabilities with old policy (forward only)\n           6. reference_log_prob: Compute log probabilities with reference model\n           7. 
actor_train: Train the actor model\n       \"\"\"\n       pipeline = Pipeline(\"dapo_training_pipeline\", \"DAPO workflow\")\n\n       pipeline.add_node(\n           \"rollout_actor\",\n           func=\"siirl.dag_worker.dagworker:DAGWorker.generate\",\n           deps=[],\n           node_type=NodeType.MODEL_INFERENCE,\n           node_role=NodeRole.ROLLOUT\n       ).add_node(\n           \"function_reward\",\n           func=\"siirl.dag_worker.dagworker:DAGWorker.compute_reward\",\n           deps=[\"rollout_actor\"],\n           node_type=NodeType.COMPUTE,\n           node_role=NodeRole.REWARD\n       ).add_node(\n           \"dynamic_sampling\",\n           func=\"siirl.user_interface.filter_interface.dapo.dynamic_sampling\",\n           deps=[\"function_reward\"],\n           node_type=NodeType.COMPUTE,\n           node_role=NodeRole.DYNAMIC_SAMPLING\n       ).add_node(\n           \"calculate_advantages\",\n           func=\"siirl.dag_worker.dagworker:DAGWorker.compute_advantage\",\n           deps=[\"dynamic_sampling\"],\n           node_type=NodeType.COMPUTE,\n           node_role=NodeRole.ADVANTAGE\n       ).add_node(\n           \"actor_old_log_prob\",\n           func=\"siirl.dag_worker.dagworker:DAGWorker.compute_old_log_prob\",\n           deps=[\"calculate_advantages\"],\n           node_type=NodeType.MODEL_TRAIN,\n           node_role=NodeRole.ACTOR,\n           only_forward_compute=True\n       ).add_node(\n           \"reference_log_prob\",\n           func=\"siirl.dag_worker.dagworker:DAGWorker.compute_ref_log_prob\",\n           deps=[\"actor_old_log_prob\"],\n           node_type=NodeType.MODEL_TRAIN,\n           node_role=NodeRole.REFERENCE\n       ).add_node(\n           \"actor_train\",\n           func=\"siirl.dag_worker.dagworker:DAGWorker.train_actor\",\n           deps=[\"reference_log_prob\"],\n           node_type=NodeType.MODEL_TRAIN,\n           node_role=NodeRole.ACTOR\n       )\n\n       return pipeline.build()\n\n**4.2.4 Embodied SRPO Pipeline (embodied_srpo_pipeline)**\n\nEmbodied AI SRPO training with data filtering and VJEPA-based reward computation:\n\n.. code-block:: python\n   :caption: siirl/execution/dag/builtin_pipelines.py - Embodied SRPO Pipeline\n\n   def embodied_srpo_pipeline() -> TaskGraph:\n       \"\"\"\n       Embodied AI GRPO training pipeline with data filtering and VJEPA-based reward computation.\n\n       Workflow:\n           1. rollout_actor: Environment rollout with embodied AI agent\n           2. embodied_sampling: Data verification and filtering\n           3. data_rebalance: Data rebalancing across workers (after filtering)\n           4. compute_reward: VJEPA-based reward computation\n           5. calculate_advantages: Calculate advantages (GRPO group-based)\n           6. actor_old_log_prob: Compute old actor log probabilities (forward only)\n           7. reference_log_prob: Compute reference model log probabilities\n           8. 
actor_train: Actor training with GRPO\n       \"\"\"\n       pipeline = Pipeline(\n           \"embodied_grpo_training_pipeline\",\n           \"Embodied AI GRPO training workflow with data filtering and VJEPA-based reward computation.\"\n       )\n\n       pipeline.add_node(\n           \"rollout_actor\",\n           func=\"siirl.dag_worker.dagworker:DAGWorker.generate\",\n           deps=[],\n           node_type=NodeType.MODEL_INFERENCE,\n           node_role=NodeRole.ROLLOUT\n       ).add_node(\n           \"dynaminc_sampling\",\n           func=\"siirl.user_interface.filter_interface.embodied.embodied_local_rank_sampling\",\n           deps=[\"rollout_actor\"],\n           node_type=NodeType.COMPUTE,\n           node_role=NodeRole.DYNAMIC_SAMPLING\n       ).add_node(\n           \"compute_reward\",\n           func=\"siirl.dag_worker.dagworker:DAGWorker.compute_reward\",\n           deps=[\"dynaminc_sampling\"],\n           node_type=NodeType.COMPUTE,\n           node_role=NodeRole.REWARD\n       ).add_node(\n           \"calculate_advantages\",\n           func=\"siirl.dag_worker.dagworker:DAGWorker.compute_advantage\",\n           deps=[\"compute_reward\"],\n           node_type=NodeType.COMPUTE,\n           node_role=NodeRole.ADVANTAGE\n       ).add_node(\n           \"actor_old_log_prob\",\n           func=\"siirl.dag_worker.dagworker:DAGWorker.compute_old_log_prob\",\n           deps=[\"calculate_advantages\"],\n           node_type=NodeType.MODEL_TRAIN,\n           node_role=NodeRole.ACTOR,\n           only_forward_compute=True\n       ).add_node(\n           \"reference_log_prob\",\n           func=\"siirl.dag_worker.dagworker:DAGWorker.compute_ref_log_prob\",\n           deps=[\"actor_old_log_prob\"],\n           node_type=NodeType.MODEL_TRAIN,\n           node_role=NodeRole.REFERENCE\n       ).add_node(\n           \"actor_train\",\n           func=\"siirl.dag_worker.dagworker:DAGWorker.train_actor\",\n           deps=[\"reference_log_prob\"],\n           node_type=NodeType.MODEL_TRAIN,\n           node_role=NodeRole.ACTOR\n       )\n\n       return pipeline.build()\n\n**Pipeline Comparison Table**:\n\n.. list-table:: Built-in Pipeline Comparison\n   :header-rows: 1\n   :widths: 15 45 40\n\n   * - Pipeline\n     - Key Difference\n     - Use Case\n   * - **GRPO**\n     - Group-based advantage normalization\n     - Reasoning tasks, math problems\n   * - **PPO**\n     - Critic model + GAE advantage estimation\n     - General RL tasks with value function\n   * - **DAPO**\n     - Dynamic sampling to filter zero-variance groups\n     - Challenging tasks with sparse rewards\n   * - **Embodied SRPO**\n     - Environment interaction + VJEPA reward + dynamic sampling\n     - Robotics, embodied AI tasks\n
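\nAll of these built-in definitions go through the ``Pipeline`` builder from section 4.1, so a custom workflow can be assembled the same way without touching the execution engine. The sketch below is illustrative only: the pipeline name, the node IDs, and the ``my_filters.keep_high_variance`` module path are hypothetical, the import lines are inferred from the file captions above, and the ``DAGWorker`` method paths and ``NodeType``/``NodeRole`` values are the ones used by the built-in pipelines (advantage and log-prob nodes are omitted for brevity):\n\n.. code-block:: python\n   :caption: Sketch: assembling a custom pipeline with the Pipeline API (illustrative)\n\n   from siirl.execution.dag.node import NodeRole, NodeType\n   from siirl.execution.dag.pipeline import Pipeline\n   from siirl.execution.dag.task_graph import TaskGraph\n\n   def my_custom_pipeline() -> TaskGraph:\n       \"\"\"Abbreviated GRPO-style workflow with a hypothetical user-defined filter node.\"\"\"\n       pipeline = Pipeline(\"my_custom_pipeline\", \"GRPO variant with a custom filter\")\n\n       pipeline.add_node(\n           \"rollout_actor\",\n           func=\"siirl.dag_worker.dagworker:DAGWorker.generate\",\n           deps=[],\n           node_type=NodeType.MODEL_INFERENCE,\n           node_role=NodeRole.ROLLOUT\n       ).add_node(\n           \"function_reward\",\n           func=\"siirl.dag_worker.dagworker:DAGWorker.compute_reward\",\n           deps=[\"rollout_actor\"],\n           node_type=NodeType.COMPUTE,\n           node_role=NodeRole.REWARD\n       ).add_node(\n           # Hypothetical module-level filter, referenced the same way the DAPO\n           # pipeline references its dynamic_sampling function\n           \"keep_high_variance\",\n           func=\"my_filters.keep_high_variance\",\n           deps=[\"function_reward\"],\n           node_type=NodeType.COMPUTE,\n           node_role=NodeRole.DYNAMIC_SAMPLING\n       ).add_node(\n           \"actor_train\",\n           func=\"siirl.dag_worker.dagworker:DAGWorker.train_actor\",\n           deps=[\"keep_high_variance\"],\n           node_type=NodeType.MODEL_TRAIN,\n           node_role=NodeRole.ACTOR\n       )\n\n       # build() validates the DAG (all dependencies exist, no cycles) and returns a TaskGraph\n       return pipeline.build()\n\n4.3 Node Data Structure\n-----------------------\n\nEach DAG node is represented by the ``Node`` class:\n\n.. 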
code-block:: python\n   :caption: siirl/execution/dag/node.py\n\n   class NodeType(Enum):\n       \"\"\"Define the types of nodes in the DAG.\"\"\"\n       COMPUTE = \"COMPUTE\"                    # General computing task\n       DATA_LOAD = \"DATA_LOAD\"                # Load data from DataLoader\n       ENV_INTERACT = \"ENV_INTERACT\"          # Interact with the environment\n       MODEL_INFERENCE = \"MODEL_INFERENCE\"    # Model inference (Rollout)\n       MODEL_TRAIN = \"MODEL_TRAIN\"            # Model training\n       PUT_TO_BUFFER = \"PUT_TO_BUFFER\"        # Put data into the distributed buffer\n       GET_FROM_BUFFER = \"GET_FROM_BUFFER\"    # Get data from the distributed buffer\n       BARRIER_SYNC = \"BARRIER_SYNC\"          # Global synchronization point\n       CUSTOM = \"CUSTOM\"                      # User-defined node type\n\n   class NodeRole(Enum):\n       \"\"\"Define the roles that a node plays in a distributed RL framework.\"\"\"\n       DEFAULT = \"DEFAULT\"                # Default role\n       ACTOR = \"ACTOR\"                    # Actor model (policy)\n       ADVANTAGE = \"ADVANTAGE\"            # Advantage computation\n       CRITIC = \"CRITIC\"                  # Critic model (value function)\n       ROLLOUT = \"ROLLOUT\"                # Rollout inference engine\n       REFERENCE = \"REFERENCE\"            # Reference model (for KL)\n       REWARD = \"REWARD\"                  # Reward computation\n       DYNAMIC_SAMPLING = \"DYNAMIC_SAMPLING\"  # Dynamic sampling in databuffer (DAPO/Embodied)\n\n   class NodeStatus(Enum):\n       \"\"\"Define the execution status of a DAG node.\"\"\"\n       PENDING = \"PENDING\"      # Waiting for dependencies to complete\n       READY = \"READY\"          # Dependencies completed, ready to execute\n       RUNNING = \"RUNNING\"      # Currently executing\n       COMPLETED = \"COMPLETED\"  # Execution completed successfully\n       FAILED = \"FAILED\"        # Execution failed\n       SKIPPED = \"SKIPPED\"      # Skipped\n\n   class Node:\n       \"\"\"Represents a node (task unit) in the DAG.\"\"\"\n       \n       def __init__(\n           self,\n           node_id: str,\n           node_type: NodeType,\n           node_role: NodeRole = NodeRole.DEFAULT,\n           only_forward_compute: bool = False,  # Forward only, no weight update\n           agent_group: int = 0,                # Multi-agent scenario grouping\n           dependencies: Optional[List[str]] = None,\n           config: Optional[Dict[str, Any]] = None,\n           executable_ref: Optional[str] = None,  # Function path \"module:Class.method\"\n           filter_plugin: Optional[Callable] = None,  # Filter function for data\n           agent_options: AgentArguments = None,\n           retry_limit: int = 0,\n       ):\n           self.node_id = node_id\n           self.node_type = node_type\n           self.node_role = node_role\n           self.only_forward_compute = only_forward_compute\n           self.agent_group = agent_group\n           self.dependencies = dependencies or []\n           self.config = config or {}\n           self.executable_ref = executable_ref\n           self.retry_limit = retry_limit\n           self._executable: Optional[Callable] = None\n           self.status = NodeStatus.PENDING\n           \n           # Resolve executable function from path\n           if self.executable_ref:\n               self._resolve_executable()\n       \n       def _resolve_executable(self) -> None:\n           \"\"\"Dynamically import and obtain the 
executable function.\n           \n           Supports two formats:\n           1. \"module.path:ClassName.method\" - imports module.path, gets ClassName.method\n           2. \"module.path.function\" - imports module.path, gets function\n           \"\"\"\n           if \":\" in self.executable_ref:\n               module_path, attr_path = self.executable_ref.split(\":\", 1)\n               module = importlib.import_module(module_path)\n               obj = module\n               for attr_name in attr_path.split(\".\"):\n                   obj = getattr(obj, attr_name)\n               self._executable = obj\n           else:\n               module_path, function_name = self.executable_ref.rsplit(\".\", 1)\n               module = importlib.import_module(module_path)\n               self._executable = getattr(module, function_name)\n       \n       def run(self, **kwargs) -> Any:\n           \"\"\"Execute the task of the node.\"\"\"\n           # `executable` is assumed to expose self._executable (accessor not shown in this excerpt)\n           if self.executable:\n               return self.executable(**kwargs)\n\n4.4 TaskGraph Data Structure\n----------------------------\n\n``TaskGraph`` represents the entire training workflow as a DAG:\n\n.. code-block:: python\n   :caption: siirl/execution/dag/task_graph.py\n\n   class TaskGraph:\n       \"\"\"Directed Acyclic Graph representing training workflow\"\"\"\n       \n       def __init__(self, graph_id: str):\n           self.graph_id = graph_id\n           self.nodes: Dict[str, Node] = {}       # node_id -> Node\n           self.adj: Dict[str, List[str]] = {}    # Forward adjacency: node -> dependents\n           self.rev_adj: Dict[str, List[str]] = {} # Reverse adjacency: node -> dependencies\n       \n       def add_node(self, node: Node) -> None:\n           \"\"\"Add node to graph\"\"\"\n           self.nodes[node.node_id] = node\n           self._update_adj_for_node(node)\n       \n       def get_topological_sort(self) -> List[str]:\n           \"\"\"Topological sort using Kahn's algorithm\"\"\"\n           # ... implement Kahn's algorithm\n       \n       def validate_graph(self) -> Tuple[bool, Optional[str]]:\n           \"\"\"Validate DAG validity (no cycles, dependencies exist)\"\"\"\n           # 1. Check all dependencies exist\n           # 2. Use topological sort to detect cycles\n           try:\n               self.get_topological_sort()\n               return True, None\n           except ValueError as e:\n               return False, str(e)\n       \n       def get_entry_nodes(self) -> List[Node]:\n           \"\"\"Get entry nodes (no dependencies)\"\"\"\n           return [node for node_id, node in self.nodes.items() \n                   if not self.rev_adj.get(node_id)]\n       \n       def get_downstream_nodes(self, node_id: str) -> List[Node]:\n           \"\"\"Get downstream nodes\"\"\"\n           return self.get_dependents(node_id)\n
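\n``get_topological_sort()`` is elided in the excerpt above. For readers who have not seen Kahn's algorithm, a minimal sketch over the same ``nodes`` / ``adj`` / ``rev_adj`` fields could look like the following (illustrative only, not the actual siiRL implementation):\n\n.. code-block:: python\n   :caption: Sketch: Kahn's algorithm over the TaskGraph adjacency lists (illustrative)\n\n   from collections import deque\n   from typing import List\n\n   def kahn_topological_sort(graph: \"TaskGraph\") -> List[str]:\n       # In-degree = number of not-yet-processed dependencies per node\n       in_degree = {nid: len(graph.rev_adj.get(nid, [])) for nid in graph.nodes}\n       queue = deque(nid for nid, deg in in_degree.items() if deg == 0)\n       order: List[str] = []\n\n       while queue:\n           nid = queue.popleft()\n           order.append(nid)\n           # Removing the node releases one dependency on each downstream node\n           for dependent in graph.adj.get(nid, []):\n               in_degree[dependent] -= 1\n               if in_degree[dependent] == 0:\n                   queue.append(dependent)\n\n       if len(order) != len(graph.nodes):\n           # Any leftover node sits on a cycle; validate_graph() relies on this error\n           raise ValueError(\"TaskGraph contains a cycle\")\n       return order\n\nFor the GRPO graph this simply returns the six node IDs in the chain order shown earlier.\n\n4.5 TaskScheduler\n-----------------\n\n``TaskScheduler`` is responsible for assigning TaskGraph to each worker:\n\n.. 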
code-block:: python\n   :caption: siirl/execution/scheduler/task_scheduler.py\n\n   class TaskScheduler:\n       \"\"\"Task Scheduler: Assign TaskGraph to Workers\"\"\"\n       \n       def __init__(self, num_physical_nodes: int, gpus_per_node: int):\n           self.num_physical_nodes = num_physical_nodes\n           self.gpus_per_node = gpus_per_node\n           self.num_workers = num_physical_nodes * gpus_per_node\n           \n           # State variables\n           self.worker_to_graph_assignment: Dict[int, Optional[TaskGraph]] = {}\n           self.node_active_worker_count: Dict[int, int] = defaultdict(int)\n           self.node_free_gpus: Dict[int, List[int]] = defaultdict(list)\n       \n       def schedule_and_assign_tasks(\n           self,\n           original_task_graphs: List[TaskGraph],\n           apportion_strategy: str = \"param_aware\",  # or \"even\"\n           consider_node_cohesion: bool = True,      # Same task on same physical node\n           consider_node_load: bool = True,          # Prefer lower load nodes\n       ) -> Dict[int, Optional[TaskGraph]]:\n           \"\"\"Schedule tasks to each worker\"\"\"\n           \n           # 1. Split original graphs into irreducible subgraphs\n           all_subgraphs = []\n           for graph in original_task_graphs:\n               subgraphs = discover_and_split_parallel_paths(graph)\n               all_subgraphs.extend(subgraphs)\n           \n           # 2. Estimate subgraph sizes and sort\n           subgraphs_with_sizes = sorted(\n               [(sg, estimate_graph_model_params(sg)) for sg in all_subgraphs],\n               key=lambda x: x[1],\n               reverse=True\n           )\n           \n           # 3. Apportion worker counts\n           workers_per_task = self._apportion_workers_to_tasks(\n               subgraphs_with_sizes,\n               self.num_workers,\n               apportion_strategy\n           )\n           \n           # 4. Place workers (considering cohesion and load balancing)\n           for task_graph, _ in subgraphs_with_sizes:\n               num_workers = workers_per_task[task_graph.graph_id]\n               for _ in range(num_workers):\n                   best_worker = self._find_best_worker(\n                       task_graph, consider_node_cohesion, consider_node_load\n                   )\n                   self.worker_to_graph_assignment[best_worker] = task_graph\n           \n           return self.worker_to_graph_assignment\n\n**Scheduling Strategy Comparison**:\n\n.. list-table:: Scheduling Strategies\n   :header-rows: 1\n   :widths: 20 40 40\n\n   * - Strategy\n     - Description\n     - Use Case\n   * - **even**\n     - Distribute workers evenly among tasks\n     - Similar task workloads\n   * - **param_aware**\n     - Distribute based on model parameter ratio\n     - Large variance in task sizes\n
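\nTo make the difference concrete, the following is a minimal sketch of proportional apportionment under the ``param_aware`` strategy. The subgraph names, parameter counts, and the inline computation are hypothetical; the real ``_apportion_workers_to_tasks`` also has to handle rounding conflicts and minimum allocations:\n\n.. code-block:: python\n   :caption: Sketch: param_aware apportionment with hypothetical numbers\n\n   # Two subgraphs produced by the splitting step: an actor path (~7B params)\n   # and a small reward-model path (~1B params), shared among 8 workers.\n   subgraph_params = {\"actor_path\": 7_000_000_000, \"reward_path\": 1_000_000_000}\n   num_workers = 8\n\n   total_params = sum(subgraph_params.values())\n   workers_per_task = {\n       graph_id: max(1, round(num_workers * size / total_params))\n       for graph_id, size in subgraph_params.items()\n   }\n   # param_aware -> {\"actor_path\": 7, \"reward_path\": 1}\n   # even        -> {\"actor_path\": 4, \"reward_path\": 4}\n\n4.6 Task Graph Splitting (task_loader.py)\n-----------------------------------------\n\nThe ``task_loader.py`` module provides utilities for analyzing and splitting complex TaskGraphs:\n\n.. 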
code-block:: python\n   :caption: siirl/execution/dag/task_loader.py\n\n   def discover_and_split_parallel_paths(src_task_graph: TaskGraph) -> List[TaskGraph]:\n       \"\"\"\n       Discovers and splits a TaskGraph into irreducible subgraphs by iteratively\n       identifying and splitting fan-out and re-converging parallel paths.\n       \n       Args:\n           src_task_graph: The original TaskGraph to be analyzed and split\n       \n       Returns:\n           List[TaskGraph]: A list of irreducible subgraph TaskGraph objects\n       \"\"\"\n       # 1. Try to split by fan-out to distinct exits\n       graphs_after_fan_out = split_by_fan_out_to_exits(current_graph, iteration_counter)\n       \n       # 2. If no fan-out split, try to split by re-converging paths\n       graphs_after_reconverge = split_by_reconverging_paths(current_graph, iteration_counter)\n       \n       # 3. If no split possible, graph is irreducible\n       return final_irreducible_graphs\n\nThis enables automatic parallelization of independent pipeline branches across different worker groups.\n\n----\n\n.. _sec5_dag_worker:\n\n5. DAG Worker Deep Dive\n=======================\n\nDAG Worker is the core execution unit of siiRL, with one DAG Worker running per GPU.\n\n5.1 DAGWorker Class Structure\n-----------------------------\n\n.. code-block:: python\n   :caption: siirl/dag_worker/dagworker.py\n\n   class DAGWorker(Worker):\n       \"\"\"DAG Execution Unit, one instance per GPU\"\"\"\n       \n       def __init__(\n           self,\n           config: SiiRLArguments,\n           process_group_manager: ProcessGroupManager,\n           taskgraph_mapping: Dict[int, TaskGraph],\n           data_coordinator: ray.actor.ActorHandle,\n           metric_worker: ray.actor.ActorHandle,\n       ):\n           # Configuration\n           self.config = config\n           self.process_group_manager = process_group_manager\n           self.taskgraph_mapping = taskgraph_mapping\n           self.data_coordinator = data_coordinator\n           \n           # State\n           self.global_steps = 0\n           self.workers: Dict[str, Any] = {}  # Node role -> Worker instance\n           self.multi_agent_group: Dict[int, Dict[NodeRole, Any]] = defaultdict(dict)\n           self.process_groups: Dict[str, ProcessGroup] = {}\n           self.internal_data_cache: Dict[str, Any] = {}\n           \n           # Initialize\n           self._initialize_worker()\n\n5.2 Initialization Flow\n-----------------------\n\nDAGWorker initialization is divided into two phases:\n\n**Phase 1: _initialize_worker() in __init__**\n\n.. code-block:: python\n\n   def _initialize_worker(self):\n       \"\"\"Initialize all Worker components\"\"\"\n       \n       # 1. Validate rank and get assigned TaskGraph\n       self._rank = get_and_validate_rank()\n       self.taskgraph = get_taskgraph_for_rank(self._rank, self.taskgraph_mapping)\n       \n       # 2. Set up distributed environment\n       self._setup_distributed_environment()\n       \n       # 3. Initialize Tokenizer\n       self._setup_tokenizers()\n       \n       # 4. Initialize DataLoader\n       self._setup_dataloader()\n       \n       # 5. Initialize Reward Manager\n       self._setup_reward_managers()\n       \n       # 6. Create role -> Worker class mapping\n       self._setup_role_worker_mapping()\n       \n       # 7. Instantiate node Workers\n       self._initialize_node_workers()\n\n**Phase 2: init_graph() method**\n\n.. 
code-block:: python\n\n   def init_graph(self):\n       \"\"\"Load model weights, restore checkpoint\"\"\"\n       \n       # 1. Load model weights to GPU\n       self._load_model_weights()\n       \n       # 2. Set up weight sharing (Actor-Rollout)\n       self._setup_sharding_manager()\n       \n       # 3. Initialize async rollout (if configured)\n       self._setup_async_rollout()\n       \n       # 4. Initialize multi-agent loop (if configured)\n       self._setup_multi_agent_loop()\n       \n       # 5. Initialize validator\n       self._init_validator()\n       \n       # 6. Initialize checkpoint manager and restore\n       self._init_checkpoint_manager()\n       self.global_steps = self.checkpoint_manager.load_checkpoint()\n       \n       # 7. Global synchronization\n       dist.barrier(self._gather_group)\n\n5.3 Training Loop\n-----------------\n\n.. code-block:: python\n   :caption: DAGWorker Training Loop Core Logic\n\n   def execute_task_graph(self):\n       \"\"\"Main entry: Execute DAG training pipeline\"\"\"\n       \n       # Optional pre-training validation\n       if self.config.trainer.val_before_train:\n           self.validator.validate(global_step=self.global_steps)\n       \n       # Main training loop\n       self._run_training_loop()\n   \n   def _run_training_loop(self):\n       \"\"\"Main training loop\"\"\"\n       \n       for epoch in range(self.config.trainer.total_epochs):\n           for batch_idx in range(self.dataloader.num_train_batches):\n               # Execute one training step\n               ordered_metrics = self._run_training_step(epoch, batch_idx)\n               self.global_steps += 1\n               \n               # Save checkpoint\n               if self.global_steps % self.config.trainer.save_freq == 0:\n                   self.checkpoint_manager.save_checkpoint(self.global_steps)\n               \n               # Execute validation\n               if self.global_steps % self.config.trainer.test_freq == 0:\n                   self.validator.validate(global_step=self.global_steps)\n               \n               # Log metrics\n               if self._rank == 0 and self.logger:\n                   self.logger.log(data=ordered_metrics, step=self.global_steps)\n\n5.4 Single Training Step Execution\n----------------------------------\n\n.. code-block:: python\n   :caption: _run_training_step() Explained\n\n   def _run_training_step(self, epoch: int, batch_idx: int) -> Optional[Dict]:\n       \"\"\"Execute a single training step\"\"\"\n       \n       # 1. Get data from DataLoader\n       batch = preprocess_dataloader(\n           self.dataloader.run(epoch=epoch, is_validation_step=False),\n           self.config.actor_rollout_ref.rollout.n\n       )\n       \n       # 2. Get DAG entry nodes\n       node_queue = self.taskgraph.get_entry_nodes()\n       entry_node_id = node_queue[0].node_id\n       visited_nodes = set()\n       \n       # 3. 
Graph traversal execution\n       while node_queue:\n           cur_node = node_queue.pop(0)\n           if cur_node.node_id in visited_nodes:\n               continue\n           visited_nodes.add(cur_node.node_id)\n           \n           # 3.1 Get node's DP/TP/PP info\n           cur_dp_size, cur_dp_rank, cur_tp_rank, cur_tp_size, cur_pp_rank, cur_pp_size = \\\n               self._get_node_dp_info(cur_node)\n           \n           # 3.2 Non-entry nodes get data from buffer\n           if cur_node.node_id != entry_node_id:\n               batch = self.get_data_from_buffers(\n                   key=cur_node.node_id,\n                   cur_dp_size=cur_dp_size,\n                   cur_dp_rank=cur_dp_rank\n               )\n           \n           # 3.3 Execute node\n           if cur_node.executable and batch is not None:\n               node_output = cur_node.run(\n                   batch=batch,\n                   config=self.config,\n                   process_group=self._get_node_process_group(cur_node),\n                   agent_group=self.multi_agent_group[cur_node.agent_group],\n                   _dag_worker_instance=self\n               )\n           else:\n               node_output = NodeOutput(batch=batch)\n           \n           # 3.4 Process output, pass to downstream nodes\n           if next_nodes := self.taskgraph.get_downstream_nodes(cur_node.node_id):\n               next_node = next_nodes[0]\n               next_dp_size = self._get_node_dp_info(next_node)[0]\n               \n               # If DP size changes, need DataCoordinator for redistribution\n               self.put_data_to_buffers(\n                   key=next_node.node_id,\n                   data=node_output.batch,\n                   source_dp_size=cur_dp_size,\n                   dest_dp_size=next_dp_size\n               )\n               \n               # Add downstream nodes to queue\n               for n in next_nodes:\n                   if n.node_id not in visited_nodes:\n                       node_queue.append(n)\n       \n       # 4. Clean up caches\n       self._cleanup_step_buffers()\n       \n       # 5. Collect and return metrics\n       return self._collect_metrics()\n\n5.5 Node Execution Methods\n--------------------------\n\nDAGWorker provides a series of node execution methods, each corresponding to a node role:\n\n.. 
code-block:: python\n   :caption: Node Execution Methods\n\n   # Rollout: Generate sequences\n   def generate(self, config, batch: TensorDict, **kwargs) -> NodeOutput:\n       \"\"\"Generate sequences using the Rollout model\"\"\"\n       agent_group = kwargs.pop(\"agent_group\")\n       is_embodied = self.config.actor_rollout_ref.model.model_type == \"embodied\"\n       \n       if is_embodied:\n           return self.generate_embodied_mode(agent_group, batch, **kwargs)\n       \n       if self.rollout_mode == 'sync':\n           gen_output = agent_group[NodeRole.ROLLOUT].generate_sequences(batch)\n           batch = batch.update(gen_output)\n           return NodeOutput(batch=batch, metrics=gen_output[\"metrics\"])\n       else:\n           return self.generate_async_mode(batch)\n   \n   # Reward: Compute rewards\n   def compute_reward(self, config, batch: TensorDict, **kwargs) -> NodeOutput:\n       \"\"\"Compute rewards for generated sequences\"\"\"\n       reward_tensor, extra_infos = compute_reward(batch, self.reward_fn)\n       batch[\"token_level_scores\"] = reward_tensor\n       \n       if config.algorithm.use_kl_in_reward:\n           batch, kl_metrics = apply_kl_penalty(batch, kl_ctrl_in_reward, ...)\n       else:\n           batch[\"token_level_rewards\"] = batch[\"token_level_scores\"]\n       \n       return NodeOutput(batch=batch, metrics=metrics)\n   \n   # Advantage: Compute advantages\n   def compute_advantage(self, config, batch: TensorDict, **kwargs) -> NodeOutput:\n       \"\"\"Compute GAE/GRPO/CPGD advantages\"\"\"\n       return NodeOutput(\n           batch=compute_advantage(\n               batch,\n               adv_estimator=config.algorithm.adv_estimator,\n               gamma=config.algorithm.gamma,\n               lam=config.algorithm.lam,\n               norm_adv_by_std_in_grpo=config.algorithm.norm_adv_by_std_in_grpo\n           )\n       )\n   \n   # Actor Forward: Compute old policy log prob\n   def compute_old_log_prob(self, config, batch: TensorDict, **kwargs) -> NodeOutput:\n       \"\"\"Compute log probabilities before policy update\"\"\"\n       agent_group = kwargs.pop(\"agent_group\")\n       processed_data = agent_group[NodeRole.ACTOR].compute_log_prob(batch)\n       return NodeOutput(batch=processed_data, metrics=processed_data.get(\"metrics\", {}))\n   \n   # Reference: Compute reference model log prob\n   def compute_ref_log_prob(self, config, batch: TensorDict, **kwargs) -> NodeOutput:\n       \"\"\"Compute reference model log probabilities\"\"\"\n       agent_group = kwargs.pop(\"agent_group\")\n       processed_data = agent_group[NodeRole.REFERENCE].compute_ref_log_prob(batch)\n       return NodeOutput(batch=processed_data, metrics=processed_data[\"metrics\"])\n   \n   # Actor Train: Train Actor model\n   def train_actor(self, config, batch: TensorDict, **kwargs) -> NodeOutput:\n       \"\"\"Execute Actor model training step\"\"\"\n       agent_group = kwargs.pop(\"agent_group\")\n       processed_data = agent_group[NodeRole.ACTOR].update_actor(batch)\n       return NodeOutput(batch=processed_data, metrics=processed_data[\"metrics\"])\n   \n   # Critic Train: Train Critic model (PPO)\n   def train_critic(self, config, batch: TensorDict, **kwargs) -> NodeOutput:\n       \"\"\"Execute Critic model training step\"\"\"\n       agent_group = kwargs.pop(\"agent_group\")\n       processed_data = agent_group[NodeRole.CRITIC].update_critic(batch)\n       return NodeOutput(batch=processed_data, metrics=processed_data[\"metrics\"])\n
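\nNot every node has to be a ``DAGWorker`` method: the DAPO pipeline in section 4.2 wires in a plain module-level function (``siirl.user_interface.filter_interface.dapo.dynamic_sampling``) by its import path, and a user-defined node can follow the same pattern. The sketch below is purely illustrative: the ``my_project.custom_nodes`` module, the function name, the threshold, and the ``NodeOutput`` import location are assumptions, and the exact keyword arguments a custom node receives may differ; it only mirrors the ``(config, batch, **kwargs) -> NodeOutput`` convention of the methods above:\n\n.. code-block:: python\n   :caption: Sketch: a hypothetical user-defined filter node (illustrative)\n\n   # my_project/custom_nodes.py, referenced from a pipeline as\n   #   func=\"my_project.custom_nodes.drop_short_responses\"\n   from siirl.dag_worker.dagworker import NodeOutput  # assumed import path\n\n   def drop_short_responses(config, batch, **kwargs):\n       \"\"\"Keep only samples whose response has enough non-padding tokens.\"\"\"\n       min_len = 8  # illustrative threshold\n       keep = batch[\"response_mask\"].sum(dim=-1) >= min_len\n       # TensorDict supports boolean masking over its batch dimension\n       filtered = batch[keep]\n       kept_ratio = keep.float().mean().item()\n       return NodeOutput(batch=filtered, metrics={\"filter/kept_ratio\": kept_ratio})\n\n----\n\n.. 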
_sec6_data_coordinator:\n\n6. Data Coordinator Deep Dive\n=============================\n\nData Coordinator is the core component of siiRL's fully distributed data management.\n\n6.1 Design Philosophy\n---------------------\n\n**Why do we need Data Coordinator?**\n\nIn traditional frameworks, all intermediate data (Rollout outputs, Reward results, etc.) must pass through a central controller for redistribution, causing severe I/O bottlenecks. siiRL's Data Coordinator adopts a different design:\n\n1. **Store only metadata and references**: Actual data is stored in Ray Object Store\n2. **Support flexible sampling strategies**: Custom sampling via filter_plugin\n3. **Automatic load balancing**: Optimize sequence length distribution via balance_partitions\n\n6.2 DataCoordinator Implementation\n----------------------------------\n\n.. code-block:: python\n   :caption: siirl/data_coordinator/data_buffer.py\n\n   @ray.remote\n   class DataCoordinator:\n       \"\"\"Global singleton data coordination Actor\"\"\"\n       \n       def __init__(self, nnodes: int, ppo_mini_batch_size: int, world_size: int):\n           self.nnodes = nnodes\n           self.ppo_mini_batch_size = ppo_mini_batch_size\n           self.world_size = world_size\n           \n           # Efficiently store metadata and references using deque\n           self._sample_queue: deque[Tuple[SampleInfo, ray.ObjectRef]] = deque()\n           self.lock = asyncio.Lock()\n           self._cache = []\n       \n       async def put_batch(\n           self, \n           sample_infos: List[SampleInfo], \n           sample_refs: List[ray.ObjectRef],\n           caller_node_id: Optional[str] = None\n       ):\n           \"\"\"Register a batch of sample references and metadata\"\"\"\n           \n           # Inject caller node ID (for subsequent routing)\n           if caller_node_id is None:\n               caller_node_id = ray.get_runtime_context().get_node_id()\n           \n           for i in range(len(sample_infos)):\n               if sample_infos[i].node_id is None:\n                   sample_infos[i].node_id = caller_node_id\n           \n           async with self.lock:\n               self._sample_queue.extend(zip(sample_infos, sample_refs))\n       \n       async def get_batch(\n           self,\n           batch_size: int,\n           dp_rank: int,\n           filter_plugin: Optional[Callable[[SampleInfo], bool]] = None,\n           balance_partitions: Optional[int] = None\n       ) -> List[ray.ObjectRef]:\n           \"\"\"Get a batch of sample ObjectRefs\"\"\"\n           \n           async with self.lock:\n               # 1. If cached, return directly\n               if len(self._cache) > 0:\n                   return self._cache[dp_rank]\n               \n               # 2. No filter, use efficient FIFO\n               if not filter_plugin:\n                   batch_items = []\n                   while self._sample_queue:\n                       item = self._sample_queue.popleft()\n                       batch_items.append(item)\n                   \n                   # Apply length balancing\n                   if balance_partitions and balance_partitions > 1:\n                       batch_refs = self._apply_length_balancing(batch_items, balance_partitions)\n                   else:\n                       batch_refs = [item[1] for item in batch_items]\n                   \n                   self._cache = batch_refs\n                   return self._cache[:batch_size]\n               \n               # 3. 
With filter, execute filtering\n               else:\n                   potential_items = [item for item in self._sample_queue \n                                      if filter_plugin(item[0])]\n                   \n                   global_batch_size = batch_size * balance_partitions\n                   if len(potential_items) < global_batch_size:\n                       return []\n                   \n                   potential_items = potential_items[:global_batch_size]\n                   \n                   # Remove selected items from queue\n                   refs_to_remove = {item[1] for item in potential_items}\n                   self._sample_queue = deque(\n                       item for item in self._sample_queue if item[1] not in refs_to_remove\n                   )\n                   \n                   # Apply length balancing and cache\n                   if balance_partitions and balance_partitions > 1:\n                       batch_refs = self._apply_length_balancing(potential_items, balance_partitions)\n                   else:\n                       batch_refs = [item[1] for item in potential_items]\n                   \n                   for rank in range(balance_partitions):\n                       self._cache.append(batch_refs[rank * batch_size: (rank + 1) * batch_size])\n                   \n                   return self._cache[dp_rank]\n\n6.3 SampleInfo Metadata\n-----------------------\n\n.. code-block:: python\n   :caption: siirl/data_coordinator/sample.py\n\n   @dataclass\n   class SampleInfo:\n       \"\"\"Sample metadata for routing and sampling\"\"\"\n       \n       sum_tokens: int = 0          # Total tokens (prompt + response)\n       prompt_length: int = 0       # Prompt length\n       response_length: int = 0     # Response length\n       uid: str = \"\"                # Unique identifier\n       node_id: Optional[str] = None  # Source node ID\n       dict_info: Dict[str, Any] = field(default_factory=dict)  # Extended info\n           # Common fields:\n           # - 'key': Target node ID\n           # - 'source_dp_size': Source DP size\n\n\n6.4 DAGWorker Data Flow Operations\n----------------------------------\n\n.. 
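\n\nBefore walking through the DAGWorker-side helpers below, note that ``get_batch`` above delegates sequence-length balancing to ``_apply_length_balancing``, which is not reproduced in this document. The following is a minimal sketch of the idea only, assuming a greedy longest-first assignment into ``balance_partitions`` equally sized buckets keyed on ``SampleInfo.sum_tokens``; the helper name and call signature are taken from the call site above, but the actual implementation may differ.\n\n.. code-block:: python\n   :caption: Length balancing sketch (illustrative, not the actual implementation)\n\n   def _apply_length_balancing(self, items, balance_partitions: int):\n       \"\"\"Spread samples across partitions so per-partition token totals stay even.\"\"\"\n       # items: List[Tuple[SampleInfo, ray.ObjectRef]]\n       capacity = len(items) // balance_partitions  # samples per partition\n       buckets = [[] for _ in range(balance_partitions)]\n       loads = [0] * balance_partitions\n       \n       # Longest samples first; always place the next sample into the lightest\n       # partition that still has room.\n       for info, ref in sorted(items, key=lambda it: it[0].sum_tokens, reverse=True):\n           open_ids = [i for i in range(balance_partitions) if len(buckets[i]) < capacity]\n           open_ids = open_ids or list(range(balance_partitions))  # leftovers if not divisible\n           target = min(open_ids, key=lambda i: loads[i])\n           buckets[target].append(ref)\n           loads[target] += info.sum_tokens\n       \n       # Flatten partition by partition so get_batch can slice one slot per dp_rank.\n       return [ref for bucket in buckets for ref in bucket]\n\n.. 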
code-block:: python\n   :caption: Data flow methods in DAGWorker\n\n   def put_data_to_buffers(\n       self,\n       key: str,\n       data: TensorDict,\n       source_dp_size: int,\n       dest_dp_size: int,\n       enforce_buffer: bool = False\n   ):\n       \"\"\"Put data into DataCoordinator\"\"\"\n       \n       # Same source and dest DP size and not forcing buffer, use local cache\n       if source_dp_size == dest_dp_size and not enforce_buffer:\n           self.internal_data_cache[key] = data\n       else:\n           # Convert to Sample list\n           samples = Dict2Samples(data)\n           \n           # Create metadata\n           sample_infos = []\n           for sample in samples:\n               sample_infos.append(SampleInfo(\n                   sum_tokens=int(sample.attention_mask.sum()),\n                   uid=str(sample.uid),\n                   dict_info={'key': key, 'source_dp_size': source_dp_size}\n               ))\n           \n           # Upload to Ray Object Store\n           sample_refs = [ray.put(sample) for sample in samples]\n           \n           # Register with DataCoordinator\n           caller_node_id = ray.get_runtime_context().get_node_id()\n           self.data_coordinator.put_batch.remote(sample_infos, sample_refs, caller_node_id)\n   \n   def get_data_from_buffers(\n       self,\n       key: str,\n       cur_dp_size: int,\n       cur_dp_rank: int\n   ) -> Optional[TensorDict]:\n       \"\"\"Get data from DataCoordinator\"\"\"\n       \n       # Check local cache first\n       if key in self.internal_data_cache:\n           return self.internal_data_cache.pop(key)\n       \n       # Define filter function\n       def key_filter(sample_info: SampleInfo) -> bool:\n           return sample_info.dict_info.get('key') == key\n       \n       # Calculate adjusted batch size\n       rollout_n = self.config.actor_rollout_ref.rollout.n\n       adjusted_batch_size = int(self.config.data.train_batch_size * rollout_n / cur_dp_size)\n       \n       # Get from DataCoordinator\n       sample_refs = ray.get(self.data_coordinator.get_batch.remote(\n           adjusted_batch_size,\n           cur_dp_rank,\n           filter_plugin=key_filter,\n           balance_partitions=cur_dp_size\n       ))\n       \n       if not sample_refs:\n           return None\n       \n       # Get actual data and collate\n       samples = ray.get(sample_refs)\n       return Samples2Dict(samples)\n\n----\n\n.. _sec7_engine:\n\n7. 
Engine Model Execution\n=========================\n\nThe Engine module contains all model Worker implementations, supporting both FSDP and Megatron training backends.\n\n7.1 Engine Module Structure\n---------------------------\n\n::\n\n   engine/\n   ├── actor/                    # Actor models\n   │   ├── base.py               # Base class\n   │   ├── dp_actor.py           # FSDP Actor\n   │   ├── megatron_actor.py     # Megatron Actor\n   │   └── embodied_actor.py     # Embodied Actor\n   ├── critic/                   # Critic models\n   │   ├── base.py\n   │   ├── dp_critic.py\n   │   └── megatron_critic.py\n   ├── rollout/                  # Rollout engine\n   │   ├── base.py\n   │   ├── vllm_rollout/         # vLLM backend\n   │   ├── sglang_rollout/       # SGLang backend\n   │   ├── hf_rollout.py         # HuggingFace backend\n   │   └── embodied_rollout.py   # Embodied Rollout\n   ├── reward_model/             # Reward models\n   ├── reward_manager/           # Reward managers\n   │   ├── naive.py              # Simple Reward\n   │   ├── parallel.py           # Parallel Reward Model\n   │   ├── dapo.py               # DAPO Reward\n   │   └── embodied.py           # Embodied Reward\n   ├── sharding_manager/         # Weight sharding management\n   │   ├── base.py\n   │   ├── fsdp_hf.py\n   │   ├── fsdp_sglang.py\n   │   ├── fsdp_vllm.py\n   │   ├── megatron_sglang.py\n   │   └── megatron_vllm.py\n   ├── fsdp_workers.py           # FSDP Worker factory\n   └── megatron_workers.py       # Megatron Worker factory\n\n7.2 Worker Base Class\n---------------------\n\nAll model Workers inherit from a unified base class:\n\n.. code-block:: python\n   :caption: siirl/engine/base_worker/base/base_worker.py\n\n   class Worker:\n       \"\"\"Abstract base class for all Workers\"\"\"\n       \n       @property\n       def world_size(self) -> int:\n           \"\"\"Get global world size\"\"\"\n           if not dist.is_initialized():\n               return 1\n           return dist.get_world_size()\n       \n       def init_model(self):\n           \"\"\"Initialize model weights (implemented by subclasses)\"\"\"\n           raise NotImplementedError\n\n7.3 Actor Worker\n----------------\n\nActor Worker is responsible for policy model training:\n\n.. code-block:: python\n   :caption: siirl/engine/actor/dp_actor.py (simplified)\n\n   class FSDPActor(Actor):\n       \"\"\"FSDP Distributed Actor\"\"\"\n       \n       def __init__(self, config, process_group: ProcessGroup):\n           self.config = config\n           self.process_group = process_group\n           \n           # Model related\n           self.model = None\n           self.optimizer = None\n           self.scheduler = None\n       \n       def init_model(self):\n           \"\"\"Initialize model, optimizer, scheduler\"\"\"\n           \n           # 1. Load model\n           self.model = self._load_model()\n           \n           # 2. Apply FSDP wrapping\n           self.model = FSDP(\n               self.model,\n               sharding_strategy=ShardingStrategy.FULL_SHARD,\n               process_group=self.process_group,\n               mixed_precision=...,\n           )\n           \n           # 3. Create optimizer\n           self.optimizer = create_optimizer(self.model, self.config.actor.optim)\n           \n           # 4. 
Create learning rate scheduler\n           self.scheduler = create_scheduler(self.optimizer, self.config.actor.optim)\n       \n       def compute_log_prob(self, batch: TensorDict) -> TensorDict:\n           \"\"\"Compute log probabilities (forward pass, no weight update)\"\"\"\n           \n           with torch.no_grad():\n               outputs = self.model(\n                   input_ids=batch[\"input_ids\"],\n                   attention_mask=batch[\"attention_mask\"],\n               )\n               \n               log_probs = compute_log_prob_from_logits(\n                   outputs.logits, batch[\"responses\"], batch[\"response_mask\"]\n               )\n           \n           batch[\"old_log_probs\"] = log_probs\n           return batch\n       \n       def update_actor(self, batch: TensorDict) -> TensorDict:\n           \"\"\"Execute Actor training step\"\"\"\n           \n           metrics = {}\n           total_loss = 0.0\n           \n           for _ in range(self.config.actor.ppo_epochs):\n               # Forward pass\n               outputs = self.model(\n                   input_ids=batch[\"input_ids\"],\n                   attention_mask=batch[\"attention_mask\"],\n               )\n               \n               # Compute current log probabilities\n               log_probs = compute_log_prob_from_logits(\n                   outputs.logits, batch[\"responses\"], batch[\"response_mask\"]\n               )\n               \n               # Compute policy loss\n               pg_loss, pg_clipfrac, ppo_kl, _ = compute_policy_loss(\n                   old_log_prob=batch[\"old_log_probs\"],\n                   log_prob=log_probs,\n                   advantages=batch[\"advantages\"],\n                   response_mask=batch[\"response_mask\"],\n                   cliprange=self.config.actor.clip_ratio,\n               )\n               \n               # Compute entropy loss\n               entropy_loss = compute_entropy_loss(outputs.logits, batch[\"response_mask\"])\n               \n               # Total loss\n               loss = pg_loss - self.config.actor.entropy_coef * entropy_loss\n               \n               # Backward pass\n               self.optimizer.zero_grad()\n               loss.backward()\n               \n               # Gradient clipping\n               if self.config.actor.max_grad_norm:\n                   torch.nn.utils.clip_grad_norm_(\n                       self.model.parameters(), self.config.actor.max_grad_norm\n                   )\n               \n               # Optimizer step\n               self.optimizer.step()\n               self.scheduler.step()\n               \n               total_loss += loss.item()\n           \n           metrics[\"actor/loss\"] = total_loss / self.config.actor.ppo_epochs\n           metrics[\"actor/pg_clipfrac\"] = pg_clipfrac.item()\n           metrics[\"actor/ppo_kl\"] = ppo_kl.item()\n           \n           batch[\"metrics\"] = metrics\n           return batch\n\n7.4 Rollout Worker\n------------------\n\nRollout Worker is responsible for sequence generation:\n\n.. 
code-block:: python\n   :caption: siirl/engine/rollout/vllm_rollout/vllm_rollout.py (simplified)\n\n   class VLLMRollout:\n       \"\"\"vLLM Inference Backend\"\"\"\n       \n       def __init__(self, config, process_group: ProcessGroup):\n           self.config = config\n           self.process_group = process_group\n           \n           # vLLM LLM instance\n           self.llm = None\n           self.tokenizer = None\n       \n       def init_model(self):\n           \"\"\"Initialize vLLM engine\"\"\"\n           \n           from vllm import LLM, SamplingParams\n           \n           self.llm = LLM(\n               model=self.config.model.path,\n               tensor_parallel_size=self.config.rollout.tensor_model_parallel_size,\n               trust_remote_code=True,\n               dtype=self.config.model.dtype,\n           )\n           \n           self.tokenizer = self.llm.get_tokenizer()\n       \n       def generate_sequences(self, batch: TensorDict) -> TensorDict:\n           \"\"\"Generate sequences\"\"\"\n           \n           from vllm import SamplingParams\n           \n           # Build sampling parameters\n           sampling_params = SamplingParams(\n               n=self.config.rollout.n,  # GRPO group size\n               temperature=self.config.rollout.temperature,\n               top_p=self.config.rollout.top_p,\n               max_tokens=self.config.data.max_response_length,\n           )\n           \n           # Prepare prompts\n           prompts = batch[\"prompts\"]  # List[str] or List[List[int]]\n           \n           # Generate\n           outputs = self.llm.generate(prompts, sampling_params)\n           \n           # Process outputs\n           all_responses = []\n           all_response_ids = []\n           \n           for output in outputs:\n               for completion in output.outputs:\n                   all_responses.append(completion.text)\n                   all_response_ids.append(completion.token_ids)\n           \n           # Update batch\n           batch[\"responses\"] = all_responses\n           batch[\"response_ids\"] = torch.tensor(all_response_ids)\n           batch[\"metrics\"] = {\n               \"rollout/avg_response_length\": np.mean([len(r) for r in all_response_ids])\n           }\n           \n           return batch\n\n7.5 Sharding Manager\n--------------------\n\nSharding Manager is responsible for weight synchronization between Actor and Rollout:\n\n.. code-block:: python\n   :caption: siirl/engine/sharding_manager/fsdp_vllm.py (simplified)\n\n   class FSDPVLLMShardingManager:\n       \"\"\"Weight synchronization between FSDP Actor and vLLM Rollout\"\"\"\n       \n       def __init__(self, actor: FSDPActor, rollout: VLLMRollout, process_group: ProcessGroup):\n           self.actor = actor\n           self.rollout = rollout\n           self.process_group = process_group\n       \n       def sync_weights_actor_to_rollout(self):\n           \"\"\"Sync Actor weights to Rollout\"\"\"\n           \n           # 1. Gather full weights from FSDP\n           with FSDP.state_dict_type(\n               self.actor.model,\n               StateDictType.FULL_STATE_DICT,\n               FullStateDictConfig(offload_to_cpu=True, rank0_only=True)\n           ):\n               state_dict = self.actor.model.state_dict()\n           \n           # 2. Broadcast to all ranks\n           dist.broadcast_object_list([state_dict], src=0, group=self.process_group)\n           \n           # 3. 
Update vLLM model weights\n           self.rollout.load_weights(state_dict)\n\n----\n\n.. _sec8_core_algorithms:\n\n8. Core Algorithm Implementation\n================================\n\n8.1 Advantage Estimators\n------------------------\n\nsiiRL supports multiple advantage estimation methods:\n\n.. code-block:: python\n   :caption: siirl/dag_worker/core_algos.py\n\n   # Registry decorator\n   ADV_ESTIMATOR_REGISTRY: dict[str, Any] = {}\n   \n   def register_adv_est(name_or_enum: str | AdvantageEstimator):\n       \"\"\"Register an advantage estimator\"\"\"\n       def decorator(fn):\n           name = name_or_enum.value if isinstance(name_or_enum, Enum) else name_or_enum\n           ADV_ESTIMATOR_REGISTRY[name] = fn\n           return fn\n       return decorator\n   \n   @register_adv_est(AdvantageEstimator.GAE)\n   def compute_gae_advantage_return(\n       token_level_rewards: torch.Tensor,  # (bs, response_length)\n       values: torch.Tensor,               # (bs, response_length)\n       response_mask: torch.Tensor,        # (bs, response_length)\n       gamma: float,\n       lam: float,\n   ):\n       \"\"\"GAE (Generalized Advantage Estimation) for PPO\"\"\"\n       with torch.no_grad():\n           nextvalues = 0\n           lastgaelam = 0\n           advantages_reversed = []\n           gen_len = token_level_rewards.shape[-1]\n           \n           for t in reversed(range(gen_len)):\n               delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t]\n               lastgaelam_ = delta + gamma * lam * lastgaelam\n               \n               # Skip padding tokens\n               nextvalues = values[:, t] * response_mask[:, t] + (1 - response_mask[:, t]) * nextvalues\n               lastgaelam = lastgaelam_ * response_mask[:, t] + (1 - response_mask[:, t]) * lastgaelam\n               \n               advantages_reversed.append(lastgaelam)\n           \n           advantages = torch.stack(advantages_reversed[::-1], dim=1)\n           returns = advantages + values\n           advantages = masked_whiten(advantages, response_mask)\n       \n       return advantages, returns\n   \n   @register_adv_est(AdvantageEstimator.GRPO)\n   def compute_grpo_outcome_advantage(\n       token_level_rewards: torch.Tensor,  # (bs, response_length)\n       response_mask: torch.Tensor,        # (bs, response_length)\n       index: np.ndarray,                  # Index for grouping\n       epsilon: float = 1e-6,\n       norm_adv_by_std_in_grpo: bool = True,\n   ):\n       \"\"\"GRPO (Group Relative Policy Optimization)\"\"\"\n       scores = token_level_rewards.sum(dim=-1)  # Sequence-level rewards\n       \n       id2score = defaultdict(list)\n       id2mean = {}\n       id2std = {}\n       \n       with torch.no_grad():\n           bsz = scores.shape[0]\n           \n           # Group by prompt\n           for i in range(bsz):\n               idx_key = int(index[i].item()) if isinstance(index[i], torch.Tensor) else int(index[i])\n               id2score[idx_key].append(scores[i])\n           \n           # Compute group mean and std\n           for idx in id2score:\n               if len(id2score[idx]) == 1:\n                   id2mean[idx] = torch.tensor(0.0)\n                   id2std[idx] = torch.tensor(1.0)\n               elif len(id2score[idx]) > 1:\n                   scores_tensor = torch.stack(id2score[idx])\n                   id2mean[idx] = torch.mean(scores_tensor)\n                   id2std[idx] = torch.std(scores_tensor)\n           \n           # Normalize\n           
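# Each sequence-level score is standardized within its own prompt group:\n           #     A_i = (s_i - mean(group)) / (std(group) + epsilon)\n           # With norm_adv_by_std_in_grpo=False (Dr.GRPO), only the group mean is subtracted.\n           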
for i in range(bsz):\n               idx_key = int(index[i].item()) if isinstance(index[i], torch.Tensor) else int(index[i])\n               if norm_adv_by_std_in_grpo:\n                   scores[i] = (scores[i] - id2mean[idx_key]) / (id2std[idx_key] + epsilon)\n               else:  # Dr.GRPO\n                   scores[i] = scores[i] - id2mean[idx_key]\n           \n           scores = scores.unsqueeze(-1) * response_mask\n       \n       return scores, scores\n\n8.2 Policy Loss Functions\n-------------------------\n\nsiiRL supports multiple policy loss functions:\n\n.. code-block:: python\n   :caption: siirl/dag_worker/core_algos.py\n\n   POLICY_LOSS_REGISTRY: dict[str, PolicyLossFn] = {}\n   \n   def register_policy_loss(name: str):\n       \"\"\"Register a policy loss function\"\"\"\n       def decorator(func: PolicyLossFn) -> PolicyLossFn:\n           POLICY_LOSS_REGISTRY[name] = func\n           return func\n       return decorator\n   \n   @register_policy_loss(\"vanilla\")\n   def compute_policy_loss_vanilla(\n       old_log_prob: torch.Tensor,\n       log_prob: torch.Tensor,\n       advantages: torch.Tensor,\n       response_mask: torch.Tensor,\n       loss_agg_mode: str = \"token-mean\",\n       config: Optional[ActorArguments] = None,\n       rollout_is_weights: torch.Tensor | None = None,\n   ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n       \"\"\"Standard PPO policy loss (dual-clip)\"\"\"\n       \n       clip_ratio = config.clip_ratio\n       clip_ratio_low = config.clip_ratio_low or clip_ratio\n       clip_ratio_high = config.clip_ratio_high or clip_ratio\n       clip_ratio_c = config.clip_ratio_c\n       \n       negative_approx_kl = log_prob - old_log_prob\n       ratio = torch.exp(negative_approx_kl)\n       ppo_kl = masked_mean(-negative_approx_kl, response_mask)\n       \n       # Standard PPO clipping\n       pg_losses1 = -advantages * ratio\n       pg_losses2 = -advantages * torch.clamp(ratio, 1 - clip_ratio_low, 1 + clip_ratio_high)\n       clip_pg_losses1 = torch.maximum(pg_losses1, pg_losses2)\n       pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)\n       \n       # Dual clipping (negative advantage scenario)\n       pg_losses3 = -advantages * clip_ratio_c\n       clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)\n       pg_clipfrac_lower = masked_mean(\n           torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), response_mask\n       )\n       \n       pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)\n       \n       # Apply importance weights\n       if rollout_is_weights is not None:\n           pg_losses = pg_losses * rollout_is_weights\n       \n       pg_loss = agg_loss(pg_losses, response_mask, loss_agg_mode)\n       \n       return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower\n   \n   @register_policy_loss(\"cpgd\")\n   def compute_policy_loss_cpgd(...):\n       \"\"\"CPGD policy loss (direct log_prob clipping)\"\"\"\n       ...\n   \n   @register_policy_loss(\"gspo\")\n   def compute_policy_loss_gspo(...):\n       \"\"\"GSPO policy loss (sequence-level importance ratio)\"\"\"\n       ...\n   \n   @register_policy_loss(\"gpg\")\n   def compute_policy_loss_gpg(...):\n       \"\"\"GPG policy loss (REINFORCE style)\"\"\"\n       ...\n\n8.3 KL Penalty\n--------------\n\n.. 
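\n\nThe ``apply_kl_penalty`` helper below relies on a per-token estimator ``kl_penalty_fn`` whose body is not reproduced in this document. The snippet below is a minimal sketch of the usual estimator family (plain log-ratio, absolute, squared, and the low-variance k3 form); the estimator names and the exact options supported by siiRL's helper are assumptions here and may differ.\n\n.. code-block:: python\n   :caption: kl_penalty_fn sketch (illustrative only)\n\n   import torch\n   \n   def kl_penalty_fn(log_prob: torch.Tensor, ref_log_prob: torch.Tensor, penalty: str = \"kl\") -> torch.Tensor:\n       \"\"\"Per-token estimate of KL(pi || pi_ref) from sampled log-probabilities.\"\"\"\n       diff = log_prob - ref_log_prob  # log pi(a|s) - log pi_ref(a|s), per token\n       if penalty == \"kl\":             # k1: unbiased but high-variance\n           return diff\n       if penalty == \"abs\":\n           return diff.abs()\n       if penalty == \"mse\":            # k2: 0.5 * (log-ratio)^2\n           return 0.5 * diff.square()\n       if penalty == \"low_var_kl\":     # k3: exp(-d) - 1 + d, always non-negative\n           return torch.clamp(torch.exp(-diff) - 1 + diff, max=10.0)\n       raise ValueError(penalty)\n\n.. 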
code-block:: python\n\n   class AdaptiveKLController:\n       \"\"\"Adaptive KL Controller\"\"\"\n       \n       def __init__(self, init_kl_coef, target_kl, horizon):\n           self.value = init_kl_coef\n           self.target = target_kl\n           self.horizon = horizon\n       \n       def update(self, current_kl, n_steps):\n           proportional_error = np.clip(current_kl / self.target - 1, -0.2, 0.2)\n           mult = 1 + proportional_error * n_steps / self.horizon\n           self.value *= mult\n   \n   def apply_kl_penalty(data: TensorDict, kl_ctrl, kl_penalty=\"kl\"):\n       \"\"\"Apply KL penalty to token-level rewards\"\"\"\n       \n       kld = kl_penalty_fn(data[\"old_log_probs\"], data[\"ref_log_prob\"], kl_penalty)\n       kld = kld * data[\"response_mask\"]\n       beta = kl_ctrl.value\n       \n       data[\"token_level_rewards\"] = data[\"token_level_scores\"] - beta * kld\n       \n       current_kl = masked_mean(kld, data[\"response_mask\"]).item()\n       kl_ctrl.update(current_kl=current_kl, n_steps=data.batch_size[0])\n       \n       return data, {\"actor/reward_kl_penalty\": current_kl, \"actor/kl_coef\": beta}\n\n----\n\n.. _sec9_execution_flow:\n\n9. Complete Execution Flow\n==========================\n\n9.1 GRPO Training Flow\n----------------------\n\nUsing GRPO as an example, showing the complete training flow:\n\n::\n\n   ┌──────────────────────────────────────────────────────────────────────────────┐\n   │                          GRPO Single Step Training Flow                       │\n   └──────────────────────────────────────────────────────────────────────────────┘\n   \n   [1. Data Loading]\n       │\n       │  DataLoader.run() → batch (prompts, attention_mask, ...)\n       │\n       ▼\n   [2. Rollout Generation]  ───────────────────────────────────────────────────────\n       │\n       │  DAGWorker.generate()\n       │      │\n       │      ├── Prepare generation batch\n       │      ├── rollout_worker.generate_sequences(batch)\n       │      │       │\n       │      │       ├── vLLM/SGLang/HF inference\n       │      │       └── Return responses, response_ids\n       │      │\n       │      └── Update batch: responses, response_mask\n       │\n       │  Output: batch with responses (bs * n_samples, seq_len)\n       │\n       ▼\n   [3. Reward Computation]  ──────────────────────────────────────────────────────\n       │\n       │  DAGWorker.compute_reward()\n       │      │\n       │      ├── reward_fn.score(batch) → token_level_scores\n       │      │\n       │      ├── (Optional) Apply KL penalty:\n       │      │       kl = old_log_prob - ref_log_prob\n       │      │       token_level_rewards = token_level_scores - β * kl\n       │      │\n       │      └── Otherwise: token_level_rewards = token_level_scores\n       │\n       │  Output: batch with token_level_rewards\n       │\n       ▼\n   [4. Advantage Computation]  ───────────────────────────────────────────────────\n       │\n       │  DAGWorker.compute_advantage()\n       │      │\n       │      └── compute_grpo_outcome_advantage()\n       │              │\n       │              ├── Compute sequence-level scores: scores = rewards.sum(dim=-1)\n       │              ├── Group by prompt\n       │              ├── Compute group mean and std\n       │              └── Normalize: (scores - mean) / std\n       │\n       │  Output: batch with advantages\n       │\n       ▼\n   [5. 
Actor Forward]  ───────────────────────────────────────────────────────────\n       │\n       │  DAGWorker.compute_old_log_prob()\n       │      │\n       │      └── actor_worker.compute_log_prob(batch)\n       │              │\n       │              ├── Forward pass (no_grad)\n       │              └── Compute old_log_probs\n       │\n       │  Output: batch with old_log_probs\n       │\n       ▼\n   [6. Reference Forward]  ───────────────────────────────────────────────────────\n       │\n       │  DAGWorker.compute_ref_log_prob()\n       │      │\n       │      └── reference_worker.compute_ref_log_prob(batch)\n       │              │\n       │              ├── Forward pass (no_grad)\n       │              └── Compute ref_log_prob\n       │\n       │  Output: batch with ref_log_prob\n       │\n       ▼\n   [7. Actor Training]  ──────────────────────────────────────────────────────────\n       │\n       │  DAGWorker.train_actor()\n       │      │\n       │      └── actor_worker.update_actor(batch)\n       │              │\n       │              ├── for _ in range(ppo_epochs):\n       │              │       │\n       │              │       ├── Forward pass → log_probs\n       │              │       ├── Compute policy loss:\n       │              │       │       pg_loss = -advantages * clipped_ratio\n       │              │       ├── Compute entropy loss\n       │              │       ├── Total loss = pg_loss - entropy_coef * entropy\n       │              │       ├── Backward pass\n       │              │       └── Optimizer step\n       │              │\n       │              └── Return metrics\n       │\n       │  Output: batch with metrics\n       │\n       ▼\n   [8. Sync Weights]\n       │\n       │  sharding_manager.sync_weights_actor_to_rollout()\n       │\n       ▼\n   [Done: Continue to next step]\n\n9.2 PPO Training Flow\n---------------------\n\nPPO adds Critic model and GAE computation compared to GRPO:\n\n::\n\n   GRPO flow + the following additional steps:\n   \n   [3.5. Value Computation] (After Reward, before Advantage)\n       │\n       │  DAGWorker.compute_value()\n       │      │\n       │      └── critic_worker.compute_values(batch)\n       │              │\n       │              ├── Forward pass (no_grad)\n       │              └── Compute values\n       │\n       │  Output: batch with values\n   \n   [4. Advantage Computation] (Uses GAE instead of GRPO)\n       │\n       │  compute_gae_advantage_return()\n       │      │\n       │      ├── Reverse iterate through response tokens\n       │      ├── Compute TD-error: δ = r + γV(s') - V(s)\n       │      └── GAE: A = δ + γλA'\n   \n   [7.5. Critic Training] (After Actor training)\n       │\n       │  DAGWorker.train_critic()\n       │      │\n       │      └── critic_worker.update_critic(batch)\n       │              │\n       │              ├── Forward pass → vpreds\n       │              ├── Compute Value loss:\n       │              │       vf_loss = clipped_mse(vpreds, returns)\n       │              ├── Backward pass\n       │              └── Optimizer step\n\n----\n\n.. _sec10_configuration:\n\n10. Configuration Parameters\n============================\n\n10.1 Configuration File Structure\n---------------------------------\n\nsiiRL uses Hydra for configuration management, with main configuration groups:\n\n.. 
code-block:: yaml\n   :caption: Configuration File Structure\n\n   # algorithm: Algorithm configuration\n   algorithm:\n     adv_estimator: grpo        # grpo/gae/cpgd/gspo\n     workflow_type: DEFAULT     # DEFAULT/DAPO/EMBODIED\n     gamma: 1.0                 # Discount factor\n     lam: 0.95                  # GAE lambda\n     use_kl_in_reward: false    # Whether to use KL penalty in reward\n     norm_adv_by_std_in_grpo: true\n     \n     kl_ctrl:\n       type: fixed              # fixed/adaptive\n       kl_coef: 0.001\n   \n   # data: Data configuration\n   data:\n     train_files: /path/to/train.parquet\n     train_batch_size: 512\n     max_prompt_length: 2048\n     max_response_length: 4096\n     num_loader_workers: 4\n   \n   # actor_rollout_ref: Model configuration\n   actor_rollout_ref:\n     model:\n       path: /path/to/model\n       dtype: bfloat16\n       trust_remote_code: true\n     \n     actor:\n       strategy: fsdp           # fsdp/megatron\n       clip_ratio: 0.2\n       entropy_coef: 0.01\n       ppo_epochs: 1\n       ppo_mini_batch_size: 256\n       max_grad_norm: 1.0\n       \n       optim:\n         lr: 1e-6\n         weight_decay: 0.01\n         scheduler: cosine_with_warmup\n         warmup_ratio: 0.1\n     \n     rollout:\n       name: vllm                # vllm/sglang/hf\n       tensor_model_parallel_size: 2\n       n: 8                      # GRPO group size\n       temperature: 1.0\n       top_p: 1.0\n       mode: sync                # sync/async\n   \n   # trainer: Trainer configuration\n   trainer:\n     n_gpus_per_node: 8\n     nnodes: 1\n     total_epochs: 30\n     save_freq: 10\n     test_freq: 5\n     val_before_train: false\n     critic_warmup: 0\n     \n     project_name: my_project\n     experiment_name: grpo_training\n     logger: wandb             # wandb/tensorboard/console\n   \n   # dag: DAG configuration\n   dag:\n     custom_pipeline_fn: null   # Custom Pipeline function path\n     enable_perf: false\n     backend_threshold: 32\n\n10.2 Key Parameter Descriptions\n-------------------------------\n\n.. list-table:: Key Configuration Parameters\n   :header-rows: 1\n   :widths: 30 15 55\n\n   * - Parameter\n     - Default\n     - Description\n   * - ``algorithm.adv_estimator``\n     - grpo\n     - Advantage estimator (grpo/gae/cpgd/gspo)\n   * - ``algorithm.workflow_type``\n     - DEFAULT\n     - Workflow type (DEFAULT/DAPO/EMBODIED)\n   * - ``data.train_batch_size``\n     - 512\n     - Global training batch size\n   * - ``actor_rollout_ref.rollout.n``\n     - 8\n     - GRPO samples per prompt\n   * - ``actor_rollout_ref.actor.clip_ratio``\n     - 0.2\n     - PPO clipping ratio\n   * - ``actor_rollout_ref.actor.ppo_epochs``\n     - 1\n     - PPO epochs per training step\n   * - ``actor_rollout_ref.rollout.tensor_model_parallel_size``\n     - 1\n     - Rollout TP size\n   * - ``trainer.save_freq``\n     - 10\n     - Checkpoint save frequency (steps)\n   * - ``trainer.test_freq``\n     - 5\n     - Validation frequency (steps)\n\n10.3 How to Add New Configuration Items\n---------------------------------------\n\nsiiRL uses Python dataclasses for configuration management. 
Here's how to add new configuration items:\n\n**Step 1: Identify the Configuration Group**\n\nConfiguration is organized into the following groups in ``siirl/params/``:\n\n::\n\n   siirl/params/\n   ├── __init__.py              # Exports all argument classes\n   ├── training_args.py         # TrainingArguments, SiiRLArguments (root)\n   ├── model_args.py            # ActorArguments, RolloutArguments, AlgorithmArguments, etc.\n   ├── data_args.py             # DataArguments\n   ├── dag_args.py              # DagArguments\n   ├── profiler_args.py         # ProfilerArguments\n   └── embodied_args.py         # EmbodiedArguments\n\n**Step 2: Add a New Field to the Appropriate Dataclass**\n\nExample: Adding a new ``max_retry_count`` field to ``TrainingArguments``:\n\n.. code-block:: python\n   :caption: siirl/params/training_args.py\n\n   from dataclasses import dataclass, field\n   from typing import Optional\n   \n   @dataclass\n   class TrainingArguments:\n       # Existing fields...\n       total_epochs: int = field(default=30, metadata={\"help\": \"Total training epochs\"})\n       save_freq: int = field(default=-1, metadata={\"help\": \"Checkpoint frequency\"})\n       \n       # Add your new field here\n       max_retry_count: int = field(\n           default=3,\n           metadata={\"help\": \"Maximum retry count for failed training steps\"}\n       )\n\n**Step 3: Add a New Argument Group (if needed)**\n\nIf adding a completely new category, create a new dataclass and register it in ``SiiRLArguments``:\n\n.. code-block:: python\n   :caption: siirl/params/my_custom_args.py (new file)\n\n   from dataclasses import dataclass, field\n   from typing import Dict, Any\n   \n   @dataclass\n   class MyCustomArguments:\n       \"\"\"Custom arguments for new feature.\"\"\"\n       \n       enable_feature: bool = field(\n           default=False,\n           metadata={\"help\": \"Enable the custom feature\"}\n       )\n       feature_threshold: float = field(\n           default=0.5,\n           metadata={\"help\": \"Threshold for the custom feature\"}\n       )\n       feature_config: Dict[str, Any] = field(\n           default_factory=dict,\n           metadata={\"help\": \"Additional configuration for the feature\"}\n       )\n       \n       def to_dict(self) -> Dict[str, Any]:\n           from dataclasses import asdict\n           return asdict(self)\n\nThen register in ``SiiRLArguments``:\n\n.. code-block:: python\n   :caption: siirl/params/training_args.py\n\n   from siirl.params.my_custom_args import MyCustomArguments\n   \n   @dataclass\n   class SiiRLArguments:\n       data: DataArguments = field(default_factory=DataArguments)\n       actor_rollout_ref: ActorRolloutRefArguments = field(default_factory=ActorRolloutRefArguments)\n       # ... existing fields ...\n       \n       # Add your new argument group\n       my_custom: MyCustomArguments = field(default_factory=MyCustomArguments)\n\n**Step 4: Export in __init__.py**\n\n.. code-block:: python\n   :caption: siirl/params/__init__.py\n\n   from .my_custom_args import MyCustomArguments\n   \n   __all__ = [\n       # ... existing exports ...\n       \"MyCustomArguments\",\n   ]\n\n**Step 5: Use in YAML Configuration**\n\nAfter adding the new fields, you can use them in your YAML configuration:\n\n.. 
code-block:: yaml\n   :caption: config.yaml\n\n   trainer:\n     total_epochs: 30\n     save_freq: 10\n     max_retry_count: 5  # Your new field\n   \n   my_custom:  # Your new argument group\n     enable_feature: true\n     feature_threshold: 0.7\n     feature_config:\n       key1: value1\n       key2: value2\n\n**Step 6: Access in Code**\n\n.. code-block:: python\n\n   def my_function(config: SiiRLArguments):\n       # Access top-level trainer config\n       max_retry = config.trainer.max_retry_count\n       \n       # Access your custom argument group\n       if config.my_custom.enable_feature:\n           threshold = config.my_custom.feature_threshold\n           extra_config = config.my_custom.feature_config\n\n**Configuration Hierarchy**:\n\n::\n\n   SiiRLArguments (root)\n   ├── data: DataArguments\n   │   ├── train_files\n   │   ├── train_batch_size\n   │   └── ...\n   ├── actor_rollout_ref: ActorRolloutRefArguments\n   │   ├── model: ModelArguments\n   │   ├── actor: ActorArguments\n   │   │   ├── strategy\n   │   │   ├── clip_ratio\n   │   │   ├── optim: OptimizerArguments\n   │   │   └── ...\n   │   ├── rollout: RolloutArguments\n   │   └── ref: RefArguments\n   ├── critic: CriticArguments\n   ├── reward_model: RewardModelArguments\n   ├── algorithm: AlgorithmArguments\n   │   ├── adv_estimator\n   │   ├── workflow_type\n   │   └── kl_ctrl: KLCtrlArguments\n   ├── trainer: TrainingArguments\n   ├── custom_reward_function: CustomRewardArguments\n   ├── dag: DagArguments\n   └── profiler: ProfilerArguments\n\n----\n\n.. _sec11_extension_guide:\n\n11. Extension Guide\n===================\n\n11.1 Custom Pipeline\n--------------------\n\nUsers can define custom Pipelines:\n\n.. code-block:: python\n   :caption: examples/custom_pipeline_example/custom_pipeline.py\n\n   from siirl.execution.dag.pipeline import Pipeline\n   from siirl.execution.dag.node import NodeType, NodeRole\n   from siirl.execution.dag.task_graph import TaskGraph\n   \n   def my_custom_pipeline() -> TaskGraph:\n       \"\"\"Custom training pipeline\"\"\"\n       pipeline = Pipeline(\"my_custom_pipeline\", \"My custom RL workflow\")\n       \n       # Add custom nodes\n       pipeline.add_node(\n           \"rollout_actor\",\n           func=\"siirl.dag_worker.dagworker:DAGWorker.generate\",\n           deps=[],\n           node_type=NodeType.MODEL_INFERENCE,\n           node_role=NodeRole.ROLLOUT\n       ).add_node(\n           \"custom_reward\",\n           func=\"my_module.custom_reward:compute_custom_reward\",  # Custom function\n           deps=[\"rollout_actor\"],\n           node_type=NodeType.COMPUTE,\n           node_role=NodeRole.REWARD\n       ).add_node(\n           \"calculate_advantages\",\n           func=\"siirl.dag_worker.dagworker:DAGWorker.compute_advantage\",\n           deps=[\"custom_reward\"],\n           node_type=NodeType.COMPUTE,\n           node_role=NodeRole.ADVANTAGE\n       ).add_node(\n           \"actor_train\",\n           func=\"siirl.dag_worker.dagworker:DAGWorker.train_actor\",\n           deps=[\"calculate_advantages\"],\n           node_type=NodeType.MODEL_TRAIN,\n           node_role=NodeRole.ACTOR\n       )\n       \n       return pipeline.build()\n\nSpecify in configuration:\n\n.. code-block:: yaml\n\n   dag:\n     custom_pipeline_fn: \"my_module.custom_pipeline:my_custom_pipeline\"\n\n11.2 Custom Reward Function\n---------------------------\n\n.. 
code-block:: python\n   :caption: siirl/user_interface/rewards_interface/custom_reward.py\n\n   import numpy as np\n   import torch\n   \n   from siirl.dag_worker.data_structures import NodeOutput\n   from tensordict import TensorDict\n   \n   def compute_custom_reward(batch: TensorDict, config, **kwargs) -> NodeOutput:\n       \"\"\"Custom Reward computation function\"\"\"\n       \n       # Get generated responses\n       responses = batch[\"responses\"]\n       prompts = batch[\"prompts\"]\n       \n       # Custom reward logic\n       rewards = []\n       for prompt, response in zip(prompts, responses):\n           # Implement your reward function\n           score = my_scoring_function(prompt, response)\n           rewards.append(score)\n       \n       # Convert to token-level rewards (float, since attention_mask is integer-typed)\n       token_level_rewards = torch.zeros_like(batch[\"attention_mask\"], dtype=torch.float32)\n       for i, score in enumerate(rewards):\n           # Assign sequence-level reward to last token\n           token_level_rewards[i, -1] = score\n       \n       batch[\"token_level_scores\"] = token_level_rewards\n       batch[\"token_level_rewards\"] = token_level_rewards\n       \n       metrics = {\"reward/mean_score\": np.mean(rewards)}\n       \n       return NodeOutput(batch=batch, metrics=metrics)\n\n11.3 Custom Advantage Estimator\n-------------------------------\n\n.. code-block:: python\n   :caption: Registering Custom Advantage Estimator\n\n   import torch\n   \n   from siirl.dag_worker.core_algos import register_adv_est\n   from siirl.execution.scheduler.enums import AdvantageEstimator\n   \n   @register_adv_est(\"my_custom_adv\")  # Or use enum\n   def compute_my_custom_advantage(\n       token_level_rewards: torch.Tensor,\n       response_mask: torch.Tensor,\n       **kwargs\n   ):\n       \"\"\"Custom Advantage estimation\"\"\"\n       \n       # Implement your advantage estimation logic\n       advantages = ...\n       returns = ...\n       \n       return advantages, returns\n\n11.4 Custom Policy Loss\n-----------------------\n\n.. 
code-block:: python\n   :caption: Registering Custom Policy Loss\n\n   from siirl.dag_worker.core_algos import register_policy_loss\n   \n   @register_policy_loss(\"my_custom_loss\")\n   def compute_my_custom_policy_loss(\n       old_log_prob: torch.Tensor,\n       log_prob: torch.Tensor,\n       advantages: torch.Tensor,\n       response_mask: torch.Tensor,\n       loss_agg_mode: str = \"token-mean\",\n       config = None,\n       rollout_is_weights = None,\n   ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n       \"\"\"Custom policy loss\"\"\"\n       \n       # Implement your policy loss logic\n       pg_loss = ...\n       pg_clipfrac = ...\n       ppo_kl = ...\n       pg_clipfrac_lower = ...\n       \n       return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower\n\n----\n\nAppendix A: Code File Navigation\n================================\n\n::\n\n   siirl/\n   ├── main_dag.py                           # Main entry point\n   ├── dag_worker/                           # DAG Worker module\n   │   ├── dagworker.py                      # Core Worker class (~1320 lines)\n   │   ├── core_algos.py                     # RL algorithm implementations\n   │   ├── dag_utils.py                      # Utility functions\n   │   ├── checkpoint_manager.py             # Checkpoint management\n   │   ├── validator.py                      # Validation logic\n   │   ├── metrics_collector.py              # Metrics collection\n   │   └── data_structures.py                # Data structure definitions\n   ├── execution/                            # Execution engine\n   │   ├── dag/                              # DAG definitions\n   │   │   ├── __init__.py                   # Module exports\n   │   │   ├── task_graph.py                 # TaskGraph class\n   │   │   ├── node.py                       # Node/NodeType/NodeRole/NodeStatus classes\n   │   │   ├── pipeline.py                   # Pipeline Builder API\n   │   │   ├── builtin_pipelines.py          # Built-in Pipelines (GRPO/PPO/DAPO/Embodied)\n   │   │   └── task_loader.py                # Graph splitting utilities\n   │   ├── scheduler/                        # Scheduler\n   │   │   ├── task_scheduler.py             # Task scheduling\n   │   │   ├── process_group_manager.py      # Process group management\n   │   │   ├── launch.py                     # Ray launcher\n   │   │   └── enums.py                      # Enum definitions\n   │   └── metric_worker/                    # Distributed metrics\n   │       └── metric_worker.py              # MetricWorker Actor\n   ├── engine/                               # Model execution engine\n   │   ├── actor/                            # Actor Workers\n   │   ├── critic/                           # Critic Workers\n   │   ├── rollout/                          # Rollout Workers (vLLM/SGLang/HF)\n   │   ├── reward_model/                     # Reward Model Workers\n   │   ├── reward_manager/                   # Reward Managers (naive/parallel/dapo/embodied)\n   │   └── sharding_manager/                 # Weight sharding management (FSDP/Megatron)\n   ├── data_coordinator/                     # Data coordinator\n   │   ├── data_buffer.py                    # DataCoordinator Actor\n   │   ├── dataloader/                       # Distributed DataLoader\n   │   ├── protocol.py                       # Data protocol\n   │   └── sample.py                         # Sample/SampleInfo\n   ├── user_interface/                       # User extension interfaces\n   │   ├── filter_interface/                 # 
Filtering plugins\n   │   │   ├── dapo.py                       # DAPO dynamic sampling\n   │   │   └── embodied.py                   # Embodied dynamic sampling\n   │   └── rewards_interface/                # Custom reward interfaces\n   ├── params/                               # Configuration parameters\n   │   ├── __init__.py                       # SiiRLArguments\n   │   ├── parser.py                         # Configuration parser\n   │   ├── data_args.py                      # Data parameters\n   │   ├── model_args.py                     # Model parameters\n   │   └── training_args.py                  # Training parameters\n   └── utils/                                # Utilities\n       ├── checkpoint/                       # Checkpoint utilities\n       ├── logger/                           # Logging utilities\n       ├── model_utils/                      # Model utilities\n       └── reward_score/                     # Reward computation\n\n----\n\nSummary\n=======\n\nThis document provides a comprehensive guide to siiRL's architecture implementation, including:\n\n1. **Architecture Overview**: siiRL's position in distributed RL systems and core advantages\n2. **DistFlow Design Philosophy**: Fully distributed, multi-controller paradigm design\n3. **Program Entry**: main_dag.py and MainRunner startup flow\n4. **DAG Planner**: Pipeline API, TaskGraph, TaskScheduler implementation\n5. **DAG Worker**: Core execution unit initialization, training loop, node execution\n6. **Data Coordinator**: Distributed data management and length balancing algorithm\n7. **Engine**: Actor/Critic/Rollout/Reference/Reward Worker implementations\n8. **Core Algorithms**: Advantage estimators, Policy Loss function implementations\n9. **Execution Flow**: Complete GRPO/PPO training flows\n10. **Configuration**: Key configuration parameters explained\n11. **Extension Guide**: Custom Pipeline, Reward, Advantage, Policy Loss\n\nBy reading this document, readers should gain a deep understanding of siiRL's design philosophy and implementation details, providing a solid foundation for future development, optimization, and extension work.\n\n**References**:\n\n- siiRL Paper: `DistFlow: A Fully Distributed RL Framework for Scalable and Efficient LLM Post-Training <https://arxiv.org/abs/2507.13833>`__\n- Official Documentation: `https://siirl.readthedocs.io/ <https://siirl.readthedocs.io/>`__\n- GitHub Repository: `https://github.com/sii-research/siiRL <https://github.com/sii-research/siiRL>`__\n"
  },
  {
    "path": "docs/programming_guide/srpo_code_explained.rst",
    "content": "SRPO Code Implementation Explained\n==================================\n\nThis document provides a comprehensive guide to understanding the SRPO (Self-Referential Policy Optimization) algorithm implementation in siiRL. SRPO is designed for training Vision-Language-Action (VLA) models in embodied AI scenarios.\n\n.. note::\n\n   **Paper Reference**: `SRPO: Self-Referential Policy Optimization for Vision-Language-Action Models <https://arxiv.org/pdf/2511.15605>`_\n\nOverview: What is SRPO?\n-----------------------\n\n**Self-Referential Policy Optimization (SRPO) for Vision-Language-Action Models** is a novel VLA-RL framework. SRPO eliminates the need for external demonstrations or manual reward engineering by leveraging successful trajectories generated by the model within the current training batch as self-references. This enables us to assign progress-based rewards to failed attempts.\n\nA core innovation is the use of **latent world representations** (V-JEPA) to robustly measure behavioral progress. Rather than relying on raw pixels or requiring domain-specific fine-tuning, we utilize compressed, transferable encodings from a world model's latent space. These representations naturally capture progress patterns across environments, making trajectory comparison accurate and generalizable.\n\nEmpirical evaluation on the LIBERO benchmark demonstrates SRPO's efficiency and effectiveness. Starting from a supervised baseline with a 48.9% success rate, SRPO achieves a 99.2% success rate on novel states within only 200 RL steps, representing a 103% relative improvement without any additional supervision. Furthermore, SRPO shows significant robustness on the LIBERO-Plus benchmark, achieving a 167% performance gain.\n\n**In siiRL, SRPO is implemented as the** ``embodied_srpo_pipeline`` **+** ``GRPO`` **advantage estimator.**\n\nCode Architecture Overview\n--------------------------\n\n.. code-block:: text\n\n   siiRL/\n   ├── siirl/\n   │   ├── execution/\n   │   │   └── dag/\n   │   │       └── builtin_pipelines.py         # embodied_srpo_pipeline() definition\n   │   ├── user_interface/\n   │   │   └── filter_interface/\n   │   │       └── embodied.py                  # embodied_local_rank_sampling()\n   │   ├── engine/\n   │   │   ├── rollout/\n   │   │   │   └── embodied_rollout.py          # EmbodiedHFRollout class\n   │   │   └── actor/\n   │   │       └── embodied_actor.py            # RobDataParallelPPOActor class\n   │   ├── dag_worker/\n   │   │   └── core_algos.py                    # GRPO advantage & PPO loss\n   │   ├── environment/\n   │   │   └── embodied/\n   │   │       └── adapters/                    # LIBERO environment adapter\n   │   └── utils/\n   │       ├── reward_score/\n   │       │   └── embodied.py                  # compute_embodied_reward()\n   │       └── embodied/\n   │           └── video_emb.py                 # VideoEmbeddingModel (V-JEPA)\n   └── examples/\n       └── embodied_srpo_trainer/\n           └── run_openvla_oft_libero_*.sh      # Training scripts\n\n\nTraining Pipeline Definition\n----------------------------\n\nThe SRPO training pipeline is defined in ``siirl/execution/dag/builtin_pipelines.py`` using the Python Pipeline API:\n\n.. code-block:: python\n   :caption: siirl/execution/dag/builtin_pipelines.py - embodied_srpo_pipeline()\n\n   def embodied_srpo_pipeline() -> TaskGraph:\n       \"\"\"\n       Embodied AI GRPO training pipeline with data filtering and VJEPA-based reward computation.\n\n       Workflow:\n           1. 
rollout_actor: Environment rollout with embodied AI agent\n           2. dynaminc_sampling: Data verification and filtering\n           3. compute_reward: VJEPA-based reward computation\n           4. calculate_advantages: Calculate advantages (GRPO group-based)\n           5. actor_old_log_prob: Compute old actor log probabilities (forward only)\n           6. reference_log_prob: Compute reference model log probabilities\n           7. actor_train: Actor training with GRPO\n       \"\"\"\n       pipeline = Pipeline(\n           \"embodied_grpo_training_pipeline\",\n           \"Embodied AI GRPO training workflow with data filtering and VJEPA-based reward computation.\"\n       )\n\n       pipeline.add_node(\n           \"rollout_actor\",\n           func=\"siirl.dag_worker.dagworker:DAGWorker.generate\",\n           deps=[],\n           node_type=NodeType.MODEL_INFERENCE,\n           node_role=NodeRole.ROLLOUT\n       ).add_node(\n           \"dynaminc_sampling\",\n           func=\"siirl.user_interface.filter_interface.embodied.embodied_local_rank_sampling\",\n           deps=[\"rollout_actor\"],\n           node_type=NodeType.COMPUTE,\n           node_role=NodeRole.DYNAMIC_SAMPLING\n       ).add_node(\n           \"compute_reward\",\n           func=\"siirl.dag_worker.dagworker:DAGWorker.compute_reward\",\n           deps=[\"dynaminc_sampling\"],\n           node_type=NodeType.COMPUTE,\n           node_role=NodeRole.REWARD\n       ).add_node(\n           \"calculate_advantages\",\n           func=\"siirl.dag_worker.dagworker:DAGWorker.compute_advantage\",\n           deps=[\"compute_reward\"],\n           node_type=NodeType.COMPUTE,\n           node_role=NodeRole.ADVANTAGE\n       ).add_node(\n           \"actor_old_log_prob\",\n           func=\"siirl.dag_worker.dagworker:DAGWorker.compute_old_log_prob\",\n           deps=[\"calculate_advantages\"],\n           node_type=NodeType.MODEL_TRAIN,\n           node_role=NodeRole.ACTOR,\n           only_forward_compute=True\n       ).add_node(\n           \"reference_log_prob\",\n           func=\"siirl.dag_worker.dagworker:DAGWorker.compute_ref_log_prob\",\n           deps=[\"actor_old_log_prob\"],\n           node_type=NodeType.MODEL_TRAIN,\n           node_role=NodeRole.REFERENCE\n       ).add_node(\n           \"actor_train\",\n           func=\"siirl.dag_worker.dagworker:DAGWorker.train_actor\",\n           deps=[\"reference_log_prob\"],\n           node_type=NodeType.MODEL_TRAIN,\n           node_role=NodeRole.ACTOR\n       )\n\n       return pipeline.build()\n\n\nData Flow Diagram\n~~~~~~~~~~~~~~~~~\n\n.. 
code-block:: text\n\n                              SRPO Training Pipeline Data Flow\n   ==============================================================================\n\n   DataLoader (task_id, trial_id)\n       |\n       v\n   +---------------------+\n   | rollout_actor       | EmbodiedHFRollout.generate_sequences()\n   | (MODEL_INFERENCE)   | -> VLA model + LIBERO environment interaction\n   +----------+----------+\n              | Output: {responses, input_ids, attention_mask, pixel_values,\n              |          complete, finish_step, vjepa_embedding, task_file_name}\n              v\n   +---------------------+\n   | dynamic_sampling    | embodied_local_rank_sampling()\n   | (COMPUTE)           | -> verify() + _filter_batch()\n   +----------+----------+ Filter by accuracy bounds & truncation\n              | Output: filtered batch (samples with 0.1 <= acc <= 0.9)\n              v\n   +---------------------+\n   | compute_reward      | compute_embodied_reward()\n   | (COMPUTE)           | -> VJEPA-based reward shaping\n   +----------+----------+ Success: reward=1.0, Failure: reward=sigmoid(distance)\n              | Output: + {token_level_scores, token_level_rewards}\n              v\n   +---------------------+\n   | calculate_advantages| compute_grpo_outcome_advantage()\n   | (COMPUTE)           | -> Group by prompt, normalize (score - mean) / std\n   +----------+----------+\n              | Output: + {advantages, returns}\n              v\n   +---------------------+\n   | actor_old_log_prob  | RobDataParallelPPOActor.compute_log_prob()\n   | (MODEL_TRAIN)       | -> Forward only, no gradient\n   | only_forward=True   |\n   +----------+----------+\n              | Output: + {old_log_probs}\n              v\n   +---------------------+\n   | reference_log_prob  | Reference model forward pass\n   | (MODEL_TRAIN)       |\n   +----------+----------+\n              | Output: + {ref_log_prob}\n              v\n   +---------------------+\n   | actor_train         | RobDataParallelPPOActor.update_policy()\n   | (MODEL_TRAIN)       | -> compute_policy_loss_vanilla() (PPO clipped loss)\n   +---------------------+\n              |\n              | Metrics: {pg_loss, pg_clipfrac, ppo_kl, grad_norm}\n              v\n   +---------------------+\n   | sync_weights        | ShardingManager (if needed)\n   +---------------------+\n\n\nCore Components Deep Dive\n-------------------------\n\n1. Rollout: Environment Interaction\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n**File**: ``siirl/engine/rollout/embodied_rollout.py``\n\n**Class**: ``EmbodiedHFRollout``\n\nThis is the core component that orchestrates the interaction between the VLA model and the simulation environment (LIBERO). It handles the complete episode generation process including action prediction, environment stepping, and visual embedding extraction.\n\nClass Initialization\n^^^^^^^^^^^^^^^^^^^^\n\n.. 
code-block:: python\n\n   class EmbodiedHFRollout(BaseRollout):\n       def __init__(self, module: nn.Module, config: ActorRolloutRefArguments):\n           self.model = module  # VLA model (e.g., OpenVLA-OFT)\n           self.config = config\n           \n           # Initialize V-JEPA embedding model for reward computation\n           self.embedding_model = VideoEmbeddingModel(\n               model_path=config.embodied.video_embedding_model_path,\n               img_size=config.embodied.embedding_img_size,\n               enable_fp16=config.embodied.embedding_enable_fp16\n           )\n           \n           # Initialize LIBERO environment adapter with parallel environments\n           self.num_workers = config.embodied.env.num_envs  # e.g., 16 parallel envs\n           self.adapter = LIBEROAdapter(\n               env_name=config.embodied.env.env_name,      # e.g., \"libero_goal\"\n               num_envs=self.num_workers,\n               max_steps=config.embodied.env.max_steps,    # e.g., 512\n               num_steps_wait=config.embodied.env.num_steps_wait,\n               model_family=config.embodied.env.model_family,\n               gpu_ids=[self._rank % self._num_gpus_per_node]\n           )\n\nMain Entry Point: generate_sequences()\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n.. code-block:: python\n\n   def generate_sequences(self, prompts):\n       \"\"\"\n       Main entry point for generating sequences.\n       Splits large batches into chunks that fit the number of parallel workers.\n       \"\"\"\n       total_batch_size = prompts.batch_size[0]\n       n_samples = prompts['n_samples'] if 'n_samples' in prompts else 1\n       \n       # Each prompt needs n_samples trajectories\n       batch_size_per_chunk = self.num_workers // n_samples\n       num_chunks = (total_batch_size + batch_size_per_chunk - 1) // batch_size_per_chunk\n       \n       all_chunk_outputs = []\n       for i in range(num_chunks):\n           # Slice out the prompts handled by this round of parallel environments\n           start_idx = i * batch_size_per_chunk\n           end_idx = min(start_idx + batch_size_per_chunk, total_batch_size)\n           chunk_prompts = prompts[start_idx:end_idx]\n           chunk_output = self._generate_chunk_rollout(chunk_prompts)\n           all_chunk_outputs.append(chunk_output)\n       \n       return torch.cat(all_chunk_outputs)\n\nEpisode Generation Loop: _generate_chunk_rollout()\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nThis is the heart of the embodied rollout - a step-by-step interaction loop between the VLA model and the environment.\n\n.. 
code-block:: python\n\n   def _generate_chunk_rollout(self, prompts):\n       \"\"\"Generate complete episodes for a chunk of tasks.\"\"\"\n       task_id = prompts['task_id']\n       trial_id = prompts['trial_id']\n       max_steps = self.config.embodied.env.max_steps\n       chunk_size = task_id.size(0)\n       \n       # Step 1: Reset all parallel environments\n       init_data_list = self.adapter._blocking_reset(\n           task_ids=task_id.reshape(-1).cpu().numpy().tolist(),\n           trial_ids=trial_id.reshape(-1).cpu().numpy().tolist(),\n       )\n       \n       # Collect initial observations\n       inputs = [self._obs_to_input(init_data['obs']) for init_data in init_data_list]\n       task_descriptions = [init_data[\"task_description\"] for init_data in init_data_list]\n       task_records = [{\"active\": d['active'], \"complete\": d['complete'],\n                        \"finish_step\": d['finish_step'], \"task_file_name\": d['task_file_name']}\n                       for d in init_data_list]\n       \n       # Step 2: Main interaction loop (up to max_steps)\n       step = 0\n       vla_history = []  # Store all step data for training\n       \n       while step < max_steps:\n           active_indices = [i for i, r in enumerate(task_records) if r['active']]\n           \n           # Step 2a: Process observations into VLA input format\n           vla_input = self.process_input(inputs, task_descriptions)\n           \n           # Step 2b: VLA model predicts actions\n           vla_output = self._generate_one_step(vla_input)\n           actions = vla_output[\"action\"]\n           \n           # Store step data for later training\n           vla_history.append({\n               \"responses\": vla_output[\"responses\"],\n               \"input_ids\": vla_output[\"input_ids\"],\n               \"attention_mask\": vla_output[\"attention_mask\"],\n               \"pixel_values\": vla_output[\"pixel_values\"],\n               \"action\": actions,\n               \"step\": step\n           })\n           \n           # Step 2c: Execute actions in environment\n           step_results_list = self.adapter._blocking_step({\n               \"indices\": active_indices,\n               \"actions\": actions,\n           })\n           \n           # Step 2d: Update observations and task status\n           for idx in active_indices:\n               result = step_results_list[idx]\n               inputs[idx] = self._obs_to_input(result['obs'])\n               task_records[idx]['active'] = result['active']\n               task_records[idx]['complete'] = result['complete']\n               task_records[idx]['finish_step'] = result['finish_step']\n           \n           step += self.config.embodied.action_chunks_len  # e.g., += 8\n       \n       # Step 3: Post-processing - Stack history and compute embeddings\n       batch = {}\n       for k in [\"responses\", \"input_ids\", \"attention_mask\", \"pixel_values\"]:\n           batch[k] = torch.stack([h[k] for h in vla_history], dim=1)\n       \n       batch[\"complete\"] = torch.tensor([r[\"complete\"] for r in task_records])\n       batch[\"finish_step\"] = torch.tensor([r[\"finish_step\"] for r in task_records])\n       \n       # Step 4: Extract V-JEPA embeddings for reward computation\n       batch_names, batch_frames = zip(*[(r['task_file_name'], all_video[r['task_file_name']])\n                                          for r in task_records])\n       vjepa_embeddings = self.embedding_model.get_embeddings(batch_names, batch_frames)\n       
batch[\"vjepa_embedding\"] = torch.tensor(np.array(vjepa_embeddings))\n       \n       return TensorDict(batch, batch_size=chunk_size)\n\nSingle-Step Action Generation: _generate_one_step()\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n.. code-block:: python\n\n   @torch.no_grad()\n   def _generate_one_step(self, prompts: dict):\n       \"\"\"Generate one action chunk from VLA model.\"\"\"\n       if self.config.embodied.embodied_type == \"openvla-oft\":\n           # OpenVLA-OFT: Action Flow Transformer variant\n           with torch.autocast(device_type='cuda', dtype=torch.bfloat16):\n               actions, response = self.model.generate_action_verl(\n                   input_ids=idx,\n                   pixel_values=pixel_values,\n                   attention_mask=attention_mask,\n                   do_sample=do_sample,\n                   unnorm_key=self.config.embodied.unnorm_key,\n                   temperature=temperature,\n               )\n           # response shape: (batch_size, action_chunks_len * action_token_len)\n           \n       elif self.config.embodied.embodied_type == \"openvla\":\n           # Standard OpenVLA: Autoregressive token generation\n           output = self.model.generate(\n               input_ids=idx,\n               pixel_values=pixel_values,\n               attention_mask=attention_mask,\n               do_sample=do_sample,\n               max_new_tokens=response_length,\n               temperature=temperature,\n           )\n           # Decode action tokens to continuous actions\n           predicted_action_token_ids = output.sequences[:, prompt_length:]\n           discretized_actions = self.model.vocab_size - predicted_action_token_ids\n           normalized_actions = self.model.bin_centers[discretized_actions]\n       \n       return {\n           'responses': response,\n           'input_ids': idx,\n           'attention_mask': attention_mask,\n           'pixel_values': pixel_values,\n           'action': actions,\n       }\n\n**Key Output Fields**:\n\n.. list-table::\n   :header-rows: 1\n   :widths: 25 30 45\n\n   * - Field\n     - Shape\n     - Description\n   * - ``responses``\n     - ``(B, traj_len, action_token_len)``\n     - Action tokens (e.g., 7-dim: xyz + quat + gripper)\n   * - ``complete``\n     - ``(B,)``\n     - Boolean: task success flag\n   * - ``finish_step``\n     - ``(B,)``\n     - Integer: episode termination step\n   * - ``vjepa_embedding``\n     - ``(B, embed_dim)``\n     - V-JEPA visual features for reward computation\n\n\n2. Data Filtering (Dynamic Sampling)\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n**File**: ``siirl/user_interface/filter_interface/embodied.py``\n\n**Function**: ``embodied_local_rank_sampling()``\n\nThis step filters out \"too easy\" or \"too hard\" prompts based on the success rate within each group.\n\n.. 
code-block:: python\n\n   def embodied_local_rank_sampling(\n       config: SiiRLArguments,\n       batch: TensorDict,\n       **kwargs: Any,\n   ) -> NodeOutput:\n       \"\"\"\n       Performs verification, metric collection, and optional filtering on a batch.\n       \"\"\"\n       # Step 1: Verify the entire batch to get scores and enrich it with an 'acc' tensor.\n       _, reward_metrics, format_metrics, reward_format_metrics = verify(batch)\n\n       # Step 2: Conditionally filter the batch based on accuracy and truncation\n       embodied_sampling = config.algorithm.embodied_sampling\n       if embodied_sampling.filter_accuracy or embodied_sampling.filter_truncated:\n           n_samples = config.actor_rollout_ref.rollout.n\n           processed_batch = _filter_batch(batch, n_samples, config)\n       else:\n           processed_batch = batch\n\n       # Aggregate the verification metrics for logging (abridged)\n       sample_metrics = {**reward_metrics, **format_metrics, **reward_format_metrics}\n       return NodeOutput(batch=processed_batch, metrics=sample_metrics)\n\n\n   def _filter_batch(batch: TensorDict, n_samples: int, config: SiiRLArguments) -> TensorDict:\n       \"\"\"\n       Filters a batch based on accuracy and truncation criteria.\n       Filtering is performed at the prompt level.\n       \"\"\"\n       num_prompts = len(batch) // n_samples\n       device = batch[\"finish_step\"].device\n\n       # --- 1. Accuracy Filtering ---\n       if config.algorithm.embodied_sampling.filter_accuracy:\n           # Reshape flat accuracy tensor into (num_prompts, n_samples)\n           acc_matrix = batch[\"acc\"].reshape(num_prompts, n_samples)\n           # Calculate mean accuracy for each prompt\n           prompt_mean_acc = acc_matrix.mean(dim=-1)\n\n           # Create a boolean mask for prompts within the desired accuracy bounds\n           accuracy_lower_bound = config.algorithm.embodied_sampling.accuracy_lower_bound\n           accuracy_upper_bound = config.algorithm.embodied_sampling.accuracy_upper_bound\n           acc_mask = (prompt_mean_acc >= accuracy_lower_bound) & (prompt_mean_acc <= accuracy_upper_bound)\n       else:\n           acc_mask = torch.ones(num_prompts, dtype=torch.bool, device=device)\n\n       # --- 2. Truncation Filtering ---\n       if config.algorithm.embodied_sampling.filter_truncated:\n           finish_steps = batch[\"finish_step\"].reshape(num_prompts, n_samples)\n           max_steps = config.actor_rollout_ref.embodied.env.max_steps\n           # A prompt is considered truncated if *any* of its samples reached max steps\n           has_truncated = (finish_steps >= max_steps).any(dim=-1)\n           trunc_mask = ~has_truncated\n       else:\n           trunc_mask = torch.ones(num_prompts, dtype=torch.bool, device=device)\n\n       # --- 3. Combine Masks and Apply Filter ---\n       combined_mask = acc_mask & trunc_mask\n       final_mask = combined_mask.repeat_interleave(n_samples)\n       filtered_batch = select_idxs(batch, final_mask)\n\n       return filtered_batch\n\n**Why Filter?**\n\n- **Too easy (acc > 0.9)**: All samples succeed → zero variance → zero advantage → no learning signal.\n- **Too hard (acc < 0.1)**: All samples fail → similar issue.\n- **Sweet spot (0.1 ≤ acc ≤ 0.9)**: Diverse outcomes → meaningful advantage estimates.\n\n\n3. Reward Computation (VJEPA-based)\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n**File**: ``siirl/utils/reward_score/embodied.py``\n\n**Function**: ``compute_embodied_reward()``\n\nThis is a key innovation of SRPO: using visual similarity to compute dense rewards for failed trajectories.\n\n
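The heart of this shaping step is a sigmoid that maps the distance between a failed trajectory's V-JEPA embedding and its nearest success cluster to a reward in ``[0, 0.6]``. Here is a minimal, self-contained sketch of just that mapping, reusing the constants from the implementation below (the array values and variable names are illustrative, not part of the siiRL API):\n\n.. code-block:: python\n\n   import numpy as np\n   from scipy import special\n\n   # Distance from each failed trajectory's embedding to its nearest success cluster\n   min_distances = np.array([0.8, 2.5, 4.1])\n\n   # Normalize distances to [0, 1] within the task group\n   d_min, d_max = min_distances.min(), min_distances.max()\n   normalized = (min_distances - d_min) / max(d_max - d_min, 1e-6)\n\n   # Steep sigmoid centered at 0.5, scaled so a failure never exceeds 0.6\n   rewards = 0.6 * special.expit(10.0 * (0.5 - normalized))\n   print(rewards)  # closest failure ~0.6, farthest ~0.0\n\nThe full reward function, including DBSCAN clustering of the successful embeddings, is shown next.\n\n.. 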
code-block:: python\n\n   def compute_embodied_reward(\n       batch_data: TensorDict,\n       **kwargs: Any,\n   ) -> List[Dict[str, Any]]:\n       \"\"\"\n       Computes rewards based on VJEPA embeddings and task completion status.\n       \n       Reward Formula:\n       - Success: reward = 1.0\n       - Failure: reward = sigmoid(distance_to_success_cluster) ∈ [0, 0.6]\n       \"\"\"\n       # --- Step 1: Data Extraction and Pre-filtering ---\n       batch_size = batch_data[\"responses\"].size(0)\n       completes = np.array(batch_data[\"complete\"].tolist())\n       embeddings = batch_data[\"vjepa_embedding\"].cpu().numpy()\n       task_file_names = _tensor_to_str_list(batch_data[\"task_file_name\"])\n\n       # Pre-filtering: Identify invalid samples (all-zero embeddings)\n       zero_embedding_mask = np.all(embeddings == 0, axis=1)\n       valid_indices = np.where(~zero_embedding_mask)[0]\n\n       # --- Step 2: Initialize rewards ---\n       final_rewards = np.zeros(batch_size, dtype=float)\n       task_names = [_extract_task_name(name) for name in task_file_names]\n\n       # Group valid samples by task name\n       task_to_valid_indices = {}\n       for idx in valid_indices:\n           task_name = task_names[idx]\n           task_to_valid_indices.setdefault(task_name, []).append(idx)\n\n       # --- Step 3: Process each task group ---\n       for task_name, indices in task_to_valid_indices.items():\n           indices = np.array(indices)\n           task_completes = completes[indices]\n\n           success_indices = indices[task_completes]\n           fail_indices = indices[~task_completes]\n\n           # Success trajectories get reward = 1.0\n           final_rewards[success_indices] = 1.0\n           \n           if len(success_indices) == 0 or len(fail_indices) == 0:\n               continue\n           \n           # a. Cluster successful embeddings using DBSCAN\n           succ_embeddings = embeddings[success_indices]\n           scaler = StandardScaler()\n           scaled_succ_embeddings = scaler.fit_transform(succ_embeddings)\n           clustering = DBSCAN(eps=0.5, min_samples=2).fit(scaled_succ_embeddings)\n\n           cluster_centers = []\n           for label in set(clustering.labels_) - {-1}:\n               cluster_points = scaled_succ_embeddings[clustering.labels_ == label]\n               center = scaler.inverse_transform(cluster_points.mean(axis=0, keepdims=True)).flatten()\n               cluster_centers.append(center)\n\n           if not cluster_centers:\n               cluster_centers = [succ_embeddings.mean(axis=0)]\n           cluster_centers = np.array(cluster_centers)\n\n           # b. Compute distance from failed trajectories to nearest success cluster\n           fail_embeddings = embeddings[fail_indices]\n           distance_matrix = cdist(fail_embeddings, cluster_centers, \"euclidean\")\n           min_distances = distance_matrix.min(axis=1)\n           \n           # c. 
Map distance to reward via sigmoid\n           max_dist, min_dist = min_distances.max(), min_distances.min()\n           dist_range = max_dist - min_dist\n           if dist_range < 1e-6:\n               normalized_dists = np.full_like(min_distances, 0.5)\n           else:\n               normalized_dists = (min_distances - min_dist) / dist_range\n\n           sigmoid_steepness = 10.0\n           sigmoid_offset = 0.5\n           sigmoid_inputs = sigmoid_steepness * (sigmoid_offset - normalized_dists)\n           reward_values = 0.6 * special.expit(sigmoid_inputs)\n\n           final_rewards[fail_indices] = reward_values\n\n       return [{\"score\": final_rewards[i]} for i in range(batch_size)]\n\n**Reward Visualization**:\n\n.. code-block:: text\n\n   Reward\n     ^\n   1.0|  ●●● (Success)\n      |\n   0.6|  ─────────────────────  (Max for failure)\n      |      ╱\n      |     ╱  Sigmoid curve\n      |    ╱\n   0.0|───╱────────────────────▶ Distance to Success\n          Near           Far\n\n**Intuition**: Failed trajectories that are \"visually similar\" to successful ones (low distance) receive higher rewards, encouraging the policy to explore in promising directions.\n\n\n4. Advantage Calculation (GRPO)\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n**File**: ``siirl/dag_worker/core_algos.py``\n\n**Function**: ``compute_grpo_outcome_advantage()``\n\nGRPO computes advantages using group-relative normalization, eliminating the need for a Critic network.\n\n.. code-block:: python\n\n   @register_adv_est(AdvantageEstimator.GRPO)\n   def compute_grpo_outcome_advantage(\n       token_level_rewards: torch.Tensor,  # (B, response_length)\n       response_mask: torch.Tensor,        # (B, response_length)\n       index: np.ndarray,                  # (B,) - prompt index for grouping\n       epsilon: float = 1e-6,\n       norm_adv_by_std_in_grpo: bool = True,\n       config: Optional[AlgorithmArguments] = None,\n   ) -> tuple[torch.Tensor, torch.Tensor]:\n       \"\"\"\n       GRPO Advantage = (reward - group_mean) / group_std\n       \n       This is the \"Self-Referential\" part: the baseline is computed from\n       the policy's own samples, not from a separate Value network.\n       \"\"\"\n       # Sum rewards across response tokens to get scalar reward per sample\n       scores = token_level_rewards.sum(dim=-1)  # (B,)\n       \n       # Group samples by prompt index\n       id2score = defaultdict(list)\n       id2mean = {}\n       id2std = {}\n\n       with torch.no_grad():\n           bsz = scores.shape[0]\n           for i in range(bsz):\n               idx_key = int(index[i].item()) if isinstance(index[i], torch.Tensor) else int(index[i])\n               id2score[idx_key].append(scores[i])\n\n           # Compute group statistics\n           for idx in id2score:\n               if len(id2score[idx]) == 1:\n                   id2mean[idx] = torch.tensor(0.0)\n                   id2std[idx] = torch.tensor(1.0)\n               elif len(id2score[idx]) > 1:\n                   scores_tensor = torch.stack(id2score[idx])\n                   id2mean[idx] = torch.mean(scores_tensor)\n                   id2std[idx] = torch.std(scores_tensor)\n\n           # Normalize: advantage = (score - mean) / std\n           # e.g., a group scoring [1, 0, 1, 1] has mean 0.75 and std 0.5,\n           # giving advantages of roughly [0.5, -1.5, 0.5, 0.5]\n           for i in range(bsz):\n               idx_key = int(index[i].item()) if isinstance(index[i], torch.Tensor) else int(index[i])\n               if norm_adv_by_std_in_grpo:\n                   scores[i] = (scores[i] - id2mean[idx_key]) / (id2std[idx_key] + epsilon)\n               else:\n                   scores[i] = scores[i] - id2mean[idx_key]  # Dr.GRPO variant\n\n           # Broadcast to token level\n           scores = scores.unsqueeze(-1) * response_mask\n\n       return scores, scores  # (advantages, returns)\n\n**Embodied-specific handling in compute_advantage()**:\n\n.. code-block:: python\n\n   def compute_advantage(data: TensorDict, adv_estimator, ...):\n       if adv_estimator == AdvantageEstimator.GRPO:\n           if \"finish_step\" in data and data[\"responses\"].ndim == 3:\n               # Embodied scenario: compute mask based on finish_step\n               responses = data[\"responses\"]\n               batch_size = responses.size(0)\n               response_length = responses.size(1) * responses.size(2)  # traj_len * action_token_len\n\n               action_token_len = responses.size(2)\n               finish_step = data['finish_step'] * action_token_len\n\n               steps = torch.arange(response_length, device=responses.device)\n               steps_expanded = steps.unsqueeze(0).expand(batch_size, -1)\n               grpo_calculation_mask = steps_expanded < finish_step.unsqueeze(1)\n           else:\n               # NLP scenario: use attention_mask-based response_mask\n               grpo_calculation_mask = data[\"response_mask\"]\n\n           advantages, returns = compute_grpo_outcome_advantage(\n               token_level_rewards=data[\"token_level_rewards\"],\n               response_mask=grpo_calculation_mask,\n               index=data[\"uid\"],\n               norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,\n           )\n\n\n5. Policy Update (PPO Loss)\n~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n**File**: ``siirl/engine/actor/embodied_actor.py``\n\n**Class**: ``RobDataParallelPPOActor``\n\n**Method**: ``update_policy()``\n\nThe actor update uses the standard PPO clipped objective with GRPO advantages.\n\n
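Both the advantage computation above and the policy update below rebuild a token-level response mask from ``finish_step``. The following is a standalone sketch of that masking logic with toy shapes (the names and values are illustrative only):\n\n.. code-block:: python\n\n   import torch\n\n   action_token_len = 7                    # tokens emitted per action chunk\n   traj_len = 4                            # action chunks per episode (toy value)\n   response_length = traj_len * action_token_len\n\n   # The two toy episodes finished after 2 and 4 action chunks respectively\n   finish_step = torch.tensor([2, 4])\n   finish_token = finish_step * action_token_len   # first invalid token index per episode\n\n   steps = torch.arange(response_length)                           # (response_length,)\n   response_mask = steps.unsqueeze(0) < finish_token.unsqueeze(1)  # (B, response_length) bool\n   print(response_mask.sum(dim=1))  # tensor([14, 28]) valid tokens per episode\n\n.. 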
code-block:: python\n\n   def update_policy(self, data: TensorDict):\n       self.actor_module.train()\n       self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu\n       temperature = data['temperature']\n\n       select_keys = ['responses', 'input_ids', 'attention_mask', 'pixel_values',\n                      'old_log_probs', 'advantages', \"finish_step\"]\n       batch = data.select(*select_keys)\n       dataloader = batch.split(self.config.ppo_mini_batch_size)\n\n       metrics = {}\n       for batch_idx, data in enumerate(dataloader):\n           mini_batch = data\n           micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)\n           self.actor_optimizer.zero_grad()\n\n           for test_idx, data in enumerate(micro_batches):\n               data = data.cuda()\n               responses = data['responses']\n               input_ids = data['input_ids']\n               attention_mask = data['attention_mask']\n               pixel_values = data['pixel_values']\n               response_length = responses.size(1) * responses.size(2)\n\n               # Build response mask from finish_step\n               finish_step = data['finish_step'] * self.config.action_token_len\n               steps = torch.arange(response_length, device=responses.device)\n               steps_expanded = steps.unsqueeze(0).expand(responses.size(0), -1)\n               response_mask = steps_expanded < finish_step.unsqueeze(1)\n\n               old_log_prob = data['old_log_probs']\n               advantages = data['advantages']\n\n               # Split the trajectory dimension into chunks for memory efficiency\n               traj_len = responses.size(1)\n               traj_split_num = int(traj_len / self.config.traj_mini_batch_size)\n               chunk_size = int(traj_len / traj_split_num)\n\n               for i in range(0, traj_len, chunk_size):\n                   # Forward pass to get current log probs for this trajectory chunk\n                   entropy, log_prob = self._forward_micro_batch_update(\n                       input_ids=input_ids[i:i+chunk_size],\n                       attention_mask=attention_mask[i:i+chunk_size],\n                       pixel_values=pixel_values[i:i+chunk_size],\n                       responses=responses[i:i+chunk_size],\n                       temperature=temperature\n                   )\n\n                   # (abridged) the *_tmp tensors are the slices of old_log_prob /\n                   # advantages / response_mask that correspond to this trajectory chunk\n\n                   # Compute PPO clipped loss\n                   pg_loss, pg_clipfrac, ppo_kl, _ = core_algos.compute_policy_loss_vanilla(\n                       old_log_prob=old_log_prob_tmp,\n                       log_prob=log_prob,\n                       advantages=advantages_tmp,\n                       response_mask=response_mask_tmp,\n                       config=self.config\n                   )\n\n                   loss = pg_loss / self.gradient_accumulation\n                   loss.backward()\n\n           grad_norm = self._optimizer_step()\n           # (abridged) accumulate pg_loss / pg_clipfrac / ppo_kl / grad_norm into `metrics`\n\n       return metrics\n\n**PPO Loss Function** (from ``core_algos.py``):\n\n
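Before looking at the full implementation, here is a toy, self-contained illustration of the per-token clipped surrogate (the numbers are made up; the real function below also handles dual-clipping and masked aggregation):\n\n.. code-block:: python\n\n   import torch\n\n   clip_ratio = 0.2\n   advantages   = torch.tensor([ 1.0, -1.0,  0.5])\n   old_log_prob = torch.tensor([-1.0, -1.0, -1.0])\n   log_prob     = torch.tensor([-0.6, -1.5, -1.0])  # current policy log-probs\n\n   ratio = torch.exp(log_prob - old_log_prob)        # importance ratio, ~[1.49, 0.61, 1.00]\n   pg_losses1 = -advantages * ratio\n   pg_losses2 = -advantages * torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio)\n   pg_loss = torch.maximum(pg_losses1, pg_losses2).mean()  # pessimistic (clipped) objective\n   print(ratio, pg_loss)\n\n.. 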
code-block:: python\n\n   @register_policy_loss(\"vanilla\")\n   def compute_policy_loss_vanilla(\n       old_log_prob: torch.Tensor,\n       log_prob: torch.Tensor,\n       advantages: torch.Tensor,\n       response_mask: torch.Tensor,\n       config: Optional[ActorArguments] = None,\n       ...\n   ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n       \"\"\"\n       L^CLIP(θ) = E[min(r_t(θ) * A_t, clip(r_t(θ), 1-ε, 1+ε) * A_t)]\n       \n       where r_t(θ) = π_θ(a_t|s_t) / π_θ_old(a_t|s_t)\n       \"\"\"\n       clip_ratio = config.clip_ratio\n       clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_ratio\n       clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_ratio\n\n       # Importance ratio\n       negative_approx_kl = log_prob - old_log_prob\n       negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)  # stability\n       ratio = torch.exp(negative_approx_kl)\n       ppo_kl = siirl_F.masked_mean(-negative_approx_kl, response_mask)\n       \n       # Clipped objective\n       pg_losses1 = -advantages * ratio\n       pg_losses2 = -advantages * torch.clamp(ratio, 1 - clip_ratio_low, 1 + clip_ratio_high)\n       clip_pg_losses1 = torch.maximum(pg_losses1, pg_losses2)\n\n       # Dual-clip for negative advantages\n       pg_losses3 = -advantages * clip_ratio_c\n       clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)\n       pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)\n\n       pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)\n       \n       return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower\n\n\nKey Configuration Parameters\n----------------------------\n\n.. list-table::\n   :header-rows: 1\n   :widths: 35 30 35\n\n   * - Parameter\n     - Location\n     - Description\n   * - ``algorithm.adv_estimator``\n     - Training config\n     - Set to ``grpo`` for SRPO\n   * - ``actor_rollout_ref.rollout.n``\n     - Training script\n     - Group size (samples per prompt)\n   * - ``algorithm.embodied_sampling.filter_accuracy``\n     - Training script\n     - Enable accuracy-based filtering\n   * - ``algorithm.embodied_sampling.accuracy_lower_bound``\n     - Training script\n     - Min success rate (default: 0.1)\n   * - ``algorithm.embodied_sampling.accuracy_upper_bound``\n     - Training script\n     - Max success rate (default: 0.9)\n   * - ``algorithm.embodied_sampling.filter_truncated``\n     - Training script\n     - Filter truncated episodes\n   * - ``actor_rollout_ref.embodied.video_embedding_model_path``\n     - Training script\n     - Path to V-JEPA model\n   * - ``actor_rollout_ref.embodied.env.num_envs``\n     - Config\n     - Number of parallel environments\n   * - ``actor_rollout_ref.embodied.env.max_steps``\n     - Config\n     - Maximum steps per episode\n   * - ``actor_rollout_ref.embodied.action_chunks_len``\n     - Config\n     - Actions per VLA forward pass\n\n\nQuick Reference: File Locations\n-------------------------------\n\n.. 
list-table::\n   :header-rows: 1\n   :widths: 30 70\n\n   * - Component\n     - File Path\n   * - Training Entry\n     - ``siirl/main_dag.py``\n   * - **Pipeline Definition**\n     - ``siirl/execution/dag/builtin_pipelines.py``\n   * - **Embodied Rollout**\n     - ``siirl/engine/rollout/embodied_rollout.py``\n   * - Environment Adapter\n     - ``siirl/environment/embodied/adapters/``\n   * - V-JEPA Embedding\n     - ``siirl/utils/embodied/video_emb.py``\n   * - **Data Filtering**\n     - ``siirl/user_interface/filter_interface/embodied.py``\n   * - **VJEPA Reward**\n     - ``siirl/utils/reward_score/embodied.py``\n   * - **GRPO Advantage**\n     - ``siirl/dag_worker/core_algos.py``\n   * - **VLA Actor**\n     - ``siirl/engine/actor/embodied_actor.py``\n   * - Example Scripts\n     - ``examples/embodied_srpo_trainer/run_openvla_oft_*.sh``\n\n\nReferences\n----------\n\n1. SRPO Paper: `Self-Referential Policy Optimization for Vision-Language-Action Models <https://arxiv.org/pdf/2511.15605>`_\n2. V-JEPA: `Video Joint Embedding Predictive Architecture 2 <https://ai.meta.com/vjepa/>`_\n\n"
  },
  {
    "path": "docs/requirements-docs.txt",
    "content": "# markdown support\nrecommonmark\nmyst_parser\n# markdown table support\nsphinx-markdown-tables\n\n# theme default rtd\n\n# crate-docs-theme\nsphinx-rtd-theme\n\n# pin tokenizers version to avoid env_logger version req\ntokenizers==0.21"
  },
  {
    "path": "docs/start/install.rst",
    "content": "Installation\n============\n\nsiiRL provides three primary installation methods. We **strongly recommend** using the Docker image for the most reliable and hassle-free experience.\n\n* :ref:`Method 1: Install from Docker Image (Recommended) <install-docker>`\n* :ref:`Method 2: Install from PyPI (pip) <install-pip>`\n* :ref:`Method 3: Install from Source (Custom Environment) <install-source>`\n\nRequirements\n------------\n\n- **Python**: Version >= 3.10\n- **CUDA**: Version >= 12.1\n\nCurrently, siiRL supports the following configurations are available:\n\n- **FSDP** for training.\n- **SGLang** and **vLLM** for rollout generation.\n\n.. _install-docker:\n\nMethod 1: Install from docker image\n------------------------------------\n\nThe stable image is ``siiai/siirl-base:vllm0.8.5.post1-sglang0.4.6.post5-cu124``. This images contains the latest version of inference and training framework and its dependencies.\n\n.. _install-pip:\n\nMethod 2: Install from PIP\n---------------------------\n\nWe provide prebuilt python wheels for Linux. Install siiRL with the following command:\n\n.. code:: bash\n\n    # Install siiRL with vLLM\n    pip install siirl[vllm]\n\n    # Then, install required high-performance dependencies for siiRL\n    pip install flashinfer-python -i https://flashinfer.ai/whl/cu124/torch2.6/\n    pip install flash-attn==2.7.3 --no-build-isolation\n\n.. _install-source:\n\nMethod 3: Install from custom environment\n---------------------------------------------\n\nWe recommend to use docker images for convenience. However, if your environment is not compatible with the docker image, you can also install siirl in a python environment.\n\nInstall dependencies\n::::::::::::::::::::\n\n1. First of all, to manage environment, we recommend using conda:\n\n.. code:: bash\n\n   conda create -n siirl python==3.10\n   conda activate siirl\n\n2. Install python packages\n\n.. note::\n    The following commands are an example for an environment with CUDA 12.4.\n    If you are using a different CUDA version, you must adjust the package versions and index URLs accordingly, especially for torch, flashinfer, and flash-attn.\n    \n.. code:: bash\n\n    pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124\n    pip install flashinfer-python -i https://flashinfer.ai/whl/cu124/torch2.6/\n    pip install flash-attn==2.7.3 --no-build-isolation\n    pip install accelerate codetiming datasets dill hydra-core pandas wandb loguru tensorboard qwen_vl_utils\n    pip install 'ray[default]>=2.47.1'\n    pip install opentelemetry-exporter-prometheus==0.47b0\n\n\n3. Then, execute the following commands to install vLLM and SGLang:\n\n.. code:: bash\n\n    pip install vllm==0.8.5.post1\n\nInstall siirl\n::::::::::::::\n\nFor installing the latest version of siirl, the best way is to clone and\ninstall it from source. Then you can modify our code to customize your\nown post-training jobs.\n\n.. code:: bash\n\n   git clone https://github.com/sii-research/siiRL.git\n   cd siirl\n   pip install -e .\n\n"
  },
  {
    "path": "docs/start/quickstart.rst",
    "content": ".. _quickstart:\n\n=========================================================\nQuickstart: GRPO training on GSM8K dataset\n=========================================================\n\nPost-train a LLM using GSM8K dataset.\n\nIntroduction\n------------\n\n.. _hf_dataset_gsm8k: https://huggingface.co/datasets/gsm8k\n\nIn this example, we train an LLM to tackle the `GSM8k <hf_dataset_gsm8k>`_ task with function-based rewards. \n\nPrerequisite:\n\n- the latest version of ``siiRL`` and its dependencies installed following the installation guide. Using the docker image is recommended.\n\n- a GPU with at least 24 GB HBM\n\n\nDataset Introduction\n--------------------\n\nGSM8k is a math problem dataset. The prompt is an elementary school\nproblem. The LLM model is asked to solve the math problem. Below is an example:\n\nPrompt\n\n   Katy makes coffee using teaspoons of sugar and cups of water in the\n   ratio of 7:13. If she used a total of 120 teaspoons of sugar and cups\n   of water, calculate the number of teaspoonfuls of sugar she used.\n\nSolution\n\n   The total ratio representing the ingredients she used to make the\n   coffee is 7+13 = <<7+13=20>>20 Since the fraction representing the\n   number of teaspoons she used is 7/20, she used 7/20\\ *120 =\n   <<7/20*\\ 120=42>>42 #### 42\n\nStep 1: Prepare the dataset\n----------------------------\n\nWe preprocess the dataset in parquet format so that (1) it contains necessary fields for computing RL rewards and (2) is faster to read.\n\n.. code-block:: bash\n\n   python3 examples/data_preprocess/gsm8k.py --local_dir ~/data/gsm8k\n\nStep 2: Download a model for post-training\n-------------------------------------------\n\nIn this example, we start with the ``Qwen2.5-0.5B-Instruct`` model.\n\n.. code-block:: bash\n\n   python3 -c \"import transformers; transformers.pipeline('text-generation', model='Qwen/Qwen2.5-0.5B-Instruct')\"\n\nStep 3: Perform GRPO training with the instruct model\n----------------------------------------------------------------------\n\n**Reward Model/Function**\n\nWe use a pre-defined rule-based reward model. We force the model to produce a final\nanswer following 4 “#” as shown in the solution. We extract the final\nanswer from both the solution and model's output using regular\nexpression matching. We assign a reward of 1 to correct\nanswer, 0.0 to incorrect answer and 0 to no answer. \n\nFor more details, please refer to `siirl/utils/reward_score/gsm8k.py <https://github.com/sii-research/siiRL/blob/main/siirl/utils/reward_score/gsm8k.py>`_.\n\n**Training Script**\n\nNow let's run GRPO training with the dataset and model above. [1]_\n\n\nSet the ``data.train_files`` ,\\ ``data.val_files``, ``actor_rollout_ref.model.path`` and ``critic.model.path`` based on your dataset and model names or paths.\n\n.. 
code-block:: bash\n\n   python3 -m siirl.main_dag \\\n    algorithm.adv_estimator=grpo \\\n    data.train_files=$HOME/data/gsm8k/train.parquet \\\n    data.val_files=$HOME/data/gsm8k/test.parquet \\\n    data.train_batch_size=128 \\\n    data.max_prompt_length=2048 \\\n    data.max_response_length=4096 \\\n    data.filter_overlong_prompts=True \\\n    data.truncation='error' \\\n    data.shuffle=False \\\n    actor_rollout_ref.model.path=$HOME/models/Qwen/Qwen2.5-0.5B-Instruct \\\n    actor_rollout_ref.actor.optim.lr=1e-6 \\\n    actor_rollout_ref.model.use_remove_padding=True \\\n    actor_rollout_ref.model.use_fused_kernels=False \\\n    actor_rollout_ref.actor.ppo_mini_batch_size=32 \\\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \\\n    actor_rollout_ref.actor.use_kl_loss=True \\\n    actor_rollout_ref.actor.grad_clip=0.5 \\\n    actor_rollout_ref.actor.clip_ratio=0.2 \\\n    actor_rollout_ref.actor.kl_loss_coef=0.01 \\\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n    actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n    actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \\\n    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n    actor_rollout_ref.rollout.name=vllm \\\n    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \\\n    actor_rollout_ref.rollout.max_model_len=8192 \\\n    actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n    actor_rollout_ref.rollout.enforce_eager=False \\\n    actor_rollout_ref.rollout.free_cache_engine=False \\\n    actor_rollout_ref.rollout.n=8 \\\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \\\n    actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n    algorithm.kl_ctrl.kl_coef=0.001 \\\n    algorithm.use_kl_in_reward=False \\\n    trainer.critic_warmup=0 \\\n    trainer.logger=['console','tensorboard']  \\\n    trainer.project_name=siirl_qwen2.5_0.5b_grpo \\\n    trainer.experiment_name=siirl_qwen2.5_0.5b_grpo_toy \\\n    trainer.n_gpus_per_node=1 \\\n    trainer.nnodes=1 \\\n    trainer.save_freq=200 \\\n    trainer.test_freq=10 \\\n    trainer.total_epochs=30 \\\n    trainer.resume_mode=auto \\\n    trainer.max_actor_ckpt_to_keep=1 \\\n    trainer.default_local_dir=ckpts/qwen2.5_0.5b/grpo/ \\\n    trainer.val_before_train=True 2>&1 | tee verl_demo.log\n\nYou are expected to see the following logs, indicating training in progress. The key metric ``val/test_score/openai/gsm8k`` is computed every ``trainer.test_freq`` steps:\n\n.. 
code-block:: bash\n\n    step:1 - training/epoch:1.000 - training/global_step:0.000 - training/rollout_probs_diff_max:0.373 - training/rollout_probs_diff_mean:0.004 - training/rollout_probs_diff_std:0.009 - actor/entropy_loss:0.438 - actor/grad_norm:0.221 - actor/lr:0.000 - actor/pg_clipfrac:0.000 - actor/pg_clipfrac_lower:0.000 - actor/pg_loss:0.003 - actor/ppo_kl:-0.000 - critic/advantages/max:1.789 - critic/advantages/mean:-0.002 - critic/advantages/min:-0.730 - critic/returns/max:1.789 - critic/returns/mean:-0.002 - critic/returns/min:-0.730 - critic/rewards/max:1.000 - critic/rewards/mean:0.013 - critic/rewards/min:0.000 - critic/score/max:1.000 - critic/score/mean:0.013 - critic/score/min:0.000 - perf/cpu_mem_used_gb:11.576 - perf/cpu_memory_used_gb:125.440 - perf/delta_time/actor:72.260 - perf/delta_time/actor_log_prob:10.829 - perf/delta_time/advantage:0.039 - perf/delta_time/compute_core_metrics:0.020 - perf/delta_time/data_loading:1.030 - perf/delta_time/get_data_from_buffer:0.001 - perf/delta_time/get_entry_node:0.000 - perf/delta_time/get_intern_data_actor_old_log_prob:0.000 - perf/delta_time/get_intern_data_actor_train:0.000 - perf/delta_time/get_intern_data_calculate_advantages:0.000 - perf/delta_time/get_intern_data_function_reward:0.000 - perf/delta_time/get_intern_data_reference_log_prob:0.000 - perf/delta_time/get_next_node:0.000 - perf/delta_time/graph_execution:128.358 - perf/delta_time/graph_loop_management:0.001 - perf/delta_time/graph_output_handling:0.002 - perf/delta_time/put_data_to_buffer:0.001 - perf/delta_time/put_intern_data_actor_old_log_prob:0.000 - perf/delta_time/put_intern_data_actor_train:0.000 - perf/delta_time/put_intern_data_calculate_advantages:0.000 - perf/delta_time/put_intern_data_function_reward:0.000 - perf/delta_time/put_intern_data_reference_log_prob:0.000 - perf/delta_time/reduce_metrics:0.036 - perf/delta_time/ref:28.170 - perf/delta_time/reference:28.172 - perf/delta_time/reset_data_buffer:0.038 - perf/delta_time/reset_intern_data_buffer:0.000 - perf/delta_time/reward:0.255 - perf/delta_time/rollout:16.797 - perf/delta_time/step:129.426 - perf/delta_time/step_barrier:0.001 - perf/max_mem_alloc_gb:34.832 - perf/max_mem_rsvd_gb:39.678 - perf/max_memory_allocated_gb:34.832 - perf/max_memory_reserved_gb:39.678 - perf/mfu/actor:0.023 - perf/mfu/actor_log_prob:0.052 - perf/mfu/ref:0.021 - perf/mfu/rollout:0.079 - response_length/clip_ratio:0.610 - response_length/max:256.000 - response_length/mean:232.029 - response_length/min:76.000 - prompt_length/clip_ratio:0.000 - prompt_length/max:189.000 - prompt_length/mean:104.727 - prompt_length/min:66.000 - perf/total_num_tokens:431047.000 - perf/time_per_step:129.426 - perf/throughput:3330.450\n    step:2 - training/epoch:1.000 - training/global_step:1.000 - training/rollout_probs_diff_max:0.326 - training/rollout_probs_diff_mean:0.004 - training/rollout_probs_diff_std:0.009 - actor/entropy_loss:0.432 - actor/grad_norm:0.210 - actor/lr:0.000 - actor/pg_clipfrac:0.000 - actor/pg_clipfrac_lower:0.000 - actor/pg_loss:0.004 - actor/ppo_kl:-0.000 - critic/advantages/max:1.789 - critic/advantages/mean:-0.004 - critic/advantages/min:-0.730 - critic/returns/max:1.789 - critic/returns/mean:-0.004 - critic/returns/min:-0.730 - critic/rewards/max:1.000 - critic/rewards/mean:0.013 - critic/rewards/min:0.000 - critic/score/max:1.000 - critic/score/mean:0.013 - critic/score/min:0.000 - perf/cpu_mem_used_gb:11.589 - perf/cpu_memory_used_gb:125.617 - perf/delta_time/actor:72.457 - perf/delta_time/actor_log_prob:10.689 
- perf/delta_time/advantage:0.040 - perf/delta_time/compute_core_metrics:0.001 - perf/delta_time/data_loading:0.005 - perf/delta_time/get_data_from_buffer:0.001 - perf/delta_time/get_entry_node:0.000 - perf/delta_time/get_intern_data_actor_old_log_prob:0.000 - perf/delta_time/get_intern_data_actor_train:0.000 - perf/delta_time/get_intern_data_calculate_advantages:0.000 - perf/delta_time/get_intern_data_function_reward:0.000 - perf/delta_time/get_intern_data_reference_log_prob:0.000 - perf/delta_time/get_next_node:0.000 - perf/delta_time/graph_execution:123.794 - perf/delta_time/graph_loop_management:0.001 - perf/delta_time/graph_output_handling:0.002 - perf/delta_time/put_data_to_buffer:0.001 - perf/delta_time/put_intern_data_actor_old_log_prob:0.000 - perf/delta_time/put_intern_data_actor_train:0.000 - perf/delta_time/put_intern_data_calculate_advantages:0.000 - perf/delta_time/put_intern_data_function_reward:0.000 - perf/delta_time/put_intern_data_reference_log_prob:0.000 - perf/delta_time/reduce_metrics:0.001 - perf/delta_time/ref:24.271 - perf/delta_time/reference:24.273 - perf/delta_time/reset_data_buffer:0.005 - perf/delta_time/reset_intern_data_buffer:0.000 - perf/delta_time/reward:0.286 - perf/delta_time/rollout:16.043 - perf/delta_time/step:123.805 - perf/delta_time/step_barrier:0.001 - perf/max_mem_alloc_gb:36.362 - perf/max_mem_rsvd_gb:41.596 - perf/max_memory_allocated_gb:36.362 - perf/max_memory_reserved_gb:41.596 - perf/mfu/actor:0.023 - perf/mfu/actor_log_prob:0.053 - perf/mfu/ref:0.024 - perf/mfu/rollout:0.082 - response_length/clip_ratio:0.595 - response_length/max:256.000 - response_length/mean:230.901 - response_length/min:20.000 - prompt_length/clip_ratio:0.000 - prompt_length/max:215.000 - prompt_length/mean:105.098 - prompt_length/min:65.000 - perf/total_num_tokens:430078.000 - perf/time_per_step:123.805 - perf/throughput:3473.837\n\nBeside, we provides a formatted, easy-to-read summary of core performance metrics on rank 0. This provides a clear, separate view of the most important indicators.\n\n.. code-block:: bash\n\n   ========================= RANK(0): Core Performance Metrics (Step: 1) =========================\n\n   --- ⏱️  Overall Performance ---\n   Step Time                   : 129.426 s\n   Throughput (tokens/s)       : 3330.45\n   Total Tokens in Step        : 431047\n\n   --- 📈 Algorithm Metrics ---\n   Actor Entropy               : 0.4380\n   Critic Rewards (Mean/Min/Max): 0.013 / 0.000 / 1.000\n   Critic Scores (Mean/Min/Max): 0.013 / 0.000 / 1.000\n\n   --- 🔥 Model Flops Utilization (MFU) ---\n   Mean MFU                    : N/A\n   Actor Training MFU          : 0.023\n   Rollout MFU                 : 0.079\n   Reference Policy MFU        : 0.021\n   Actor LogProb MFU           : 0.052\n\n   --- 💾 Memory Usage ---\n   Max GPU Memory Allocated    : 34.83 GB\n   Max GPU Memory Reserved     : 39.68 GB\n   CPU Memory Used             : 11.58 GB\n\n   --- 📏 Sequence Lengths ---\n   Prompt Length (Mean/Max)    : 104.7 / 189\n   Response Length (Mean/Max)  : 232.0 / 256\n\n   ==================================================================================\n\nCheckout ``Algorithm Baselines`` page for full training and validation logs for reference.\n\n\nIf you encounter out of memory issues with HBM less than 32GB, enable the following configs would help:\n\n.. 
code-block:: bash\n\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \\\n    critic.ppo_micro_batch_size_per_gpu=1 \\\n\nFor the full set of configs, please refer to :ref:`config-explain-page` for detailed explanation and performance tuning.\n\n\n.. [1] More training script examples for FSDP backend are stored in `examples/ppo_trainer <https://github.com/sii-research/siiRL/tree/main/examples/ppo_trainer>`_ directory."
  },
  {
    "path": "docs/user_interface/filter_interface.rst",
    "content": "================\nFilter Interface\n================\n\nFilter interface is used for dynamic sampling and data filtering in Pipelines.\n\n**Location:** ``siirl/user_interface/filter_interface/``\n\nArchitecture Overview\n---------------------\n\n::\n\n                              Filter Interface Architecture\n   ==============================================================================\n\n   +------------------+     +-------------------+     +------------------+\n   |   Previous Node  |     |   Filter Node     |     |    Next Node     |\n   |   (e.g. Reward)  |---->|   (COMPUTE type)  |---->|  (e.g. Advantage)|\n   +------------------+     +-------------------+     +------------------+\n                                    |\n                                    v\n                            +---------------+\n                            | Filter Logic  |\n                            +---------------+\n                            | 1. Get batch  |\n                            | 2. Compute    |\n                            |    mask       |\n                            | 3. Apply      |\n                            |    filter     |\n                            | 4. Return     |\n                            |    NodeOutput |\n                            +---------------+\n\n   ==============================================================================\n\n   Filter Execution Flow:\n\n   Input Batch              Filter Function              Output\n   +-----------+            +-------------+            +-----------+\n   | samples   |            |             |            | filtered  |\n   | [0,1,2,3, |  ------->  |  mask =     |  ------->  | samples   |\n   |  4,5,6,7] |            |  [T,T,F,T,  |            | [0,1,3,5] |\n   +-----------+            |   F,T,F,F]  |            +-----------+\n                            +-------------+\n                                  |\n                                  v\n                            +-------------+\n                            |  Metrics:   |\n                            | kept_ratio  |\n                            | kept_groups |\n                            +-------------+\n\nBuilt-in Filters\n----------------\n\nDAPO Dynamic Sampling\n~~~~~~~~~~~~~~~~~~~~~\n\n**Location:** ``siirl/user_interface/filter_interface/dapo.py``\n\n**Function:** ``dynamic_sampling()``\n\nFilters zero-variance sample groups (all correct or all incorrect).\n\n**Flow Diagram:**\n\n::\n\n   Input: Batch with rewards grouped by uid (prompt)\n   +-----------------------------------------------------------+\n   |  uid=0: [1.0, 1.0, 1.0, 1.0]  -> std=0 -> FILTER OUT     |\n   |  uid=1: [1.0, 0.0, 1.0, 0.0]  -> std>0 -> KEEP           |\n   |  uid=2: [0.0, 0.0, 0.0, 0.0]  -> std=0 -> FILTER OUT     |\n   |  uid=3: [0.5, 0.8, 0.2, 0.9]  -> std>0 -> KEEP           |\n   +-----------------------------------------------------------+\n   Output: Only uid=1 and uid=3 samples remain\n\n**How it works:**\n\n1. Group samples by uid (prompt)\n2. Calculate variance for each group\n3. Filter groups with variance = 0\n\n**Configuration:**\n\n.. code-block:: bash\n\n   python -m siirl.main_dag \\\n     algorithm.workflow_type=DAPO \\\n     algorithm.filter_groups.enable=true \\\n     algorithm.filter_groups.metric=seq_final_reward\n\n**Usage in Pipeline:**\n\n.. 
code-block:: python\n\n   pipeline.add_node(\n       \"dynamic_sampling\",\n       func=\"siirl.user_interface.filter_interface.dapo:dynamic_sampling\",\n       deps=[\"function_reward\"],\n       node_type=NodeType.COMPUTE,\n       node_role=NodeRole.DYNAMIC_SAMPLING\n   )\n\n**Returned Metrics:**\n\n- ``dapo_sampling/kept_trajectories_ratio``\n- ``dapo_sampling/kept_groups``\n- ``dapo_sampling/total_groups``\n\nEmbodied AI Sampling\n~~~~~~~~~~~~~~~~~~~~\n\n**Location:** ``siirl/user_interface/filter_interface/embodied.py``\n\n**Function:** ``embodied_local_rank_sampling()``\n\nFilters Embodied AI data based on task completion and accuracy.\n\n**Flow Diagram:**\n\n::\n\n   Input: Embodied rollout batch\n   +-----------------------------------------------------------------------+\n   |                                                                       |\n   |  Step 1: verify() - Compute accuracy from 'complete' field            |\n   |  +-------------------------------------------------------------------+|\n   |  | Sample 0: complete=True  -> acc=1.0                               ||\n   |  | Sample 1: complete=False -> acc=0.0                               ||\n   |  | ...                                                               ||\n   |  +-------------------------------------------------------------------+|\n   |                                                                       |\n   |  Step 2: _filter_batch() - Apply filters                              |\n   |  +-------------------------------------------------------------------+|\n   |  | Accuracy Filter (per prompt group):                               ||\n   |  |   prompt_mean_acc >= lower_bound (0.1)  AND                       ||\n   |  |   prompt_mean_acc <= upper_bound (0.9)                            ||\n   |  |                                                                   ||\n   |  | Truncation Filter:                                                ||\n   |  |   finish_step < max_steps (not truncated)                         ||\n   |  +-------------------------------------------------------------------+|\n   |                                                                       |\n   +-----------------------------------------------------------------------+\n   Output: Filtered batch (only \"learnable\" samples)\n\n**Features:**\n\n- Task verification\n- Accuracy-based filtering\n- Truncated trajectory filtering\n\n**Configuration:**\n\n.. code-block:: bash\n\n   python -m siirl.main_dag \\\n     algorithm.workflow_type=EMBODIED \\\n     algorithm.embodied_sampling.filter_accuracy=true \\\n     algorithm.embodied_sampling.filter_truncated=true \\\n     algorithm.embodied_sampling.accuracy_lower_bound=0.0 \\\n     algorithm.embodied_sampling.accuracy_upper_bound=1.0 \\\n     actor_rollout_ref.embodied.env.max_steps=100\n\n**Usage in Pipeline:**\n\n.. code-block:: python\n\n   pipeline.add_node(\n       \"dynamic_sampling\",\n       func=\"siirl.user_interface.filter_interface.embodied:embodied_local_rank_sampling\",\n       deps=[\"rollout_actor\"],\n       node_type=NodeType.COMPUTE,\n       node_role=NodeRole.DYNAMIC_SAMPLING\n   )\n\nCustom Filter\n-------------\n\nBasic Template\n~~~~~~~~~~~~~~\n\n.. 
code-block:: python\n\n   from siirl.params import SiiRLArguments\n   from siirl.dag_worker.data_structures import NodeOutput\n   from siirl.data_coordinator.sample import filter_tensordict\n   import torch\n\n   def my_custom_filter(\n       config: SiiRLArguments,\n       batch,\n       **kwargs\n   ) -> NodeOutput:\n       \"\"\"Custom filter function\"\"\"\n\n       # Get data\n       rewards = batch.batch[\"rewards\"]\n\n       # Create filter mask\n       mask = rewards > threshold  # Boolean tensor\n\n       # Apply filter\n       filtered_batch = filter_tensordict(batch, mask)\n\n       # Collect metrics\n       metrics = {\n           \"filter/kept_ratio\": mask.sum().item() / len(mask)\n       }\n\n       return NodeOutput(batch=filtered_batch, metrics=metrics)\n\nExample: Reward Threshold Filter\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. code-block:: python\n\n   def reward_threshold_filter(\n       config: SiiRLArguments,\n       batch,\n       **kwargs\n   ) -> NodeOutput:\n       \"\"\"Filter samples below reward threshold\"\"\"\n\n       rewards = batch.batch[\"rewards\"]\n       threshold = config.algorithm.filter_threshold\n\n       # Create mask\n       mask = rewards > threshold\n\n       # Apply filter\n       from siirl.data_coordinator.sample import filter_tensordict\n       filtered_batch = filter_tensordict(batch, mask)\n\n       # Metrics\n       metrics = {\n           \"filter/kept_ratio\": mask.sum().item() / len(mask),\n           \"filter/threshold\": threshold\n       }\n\n       return NodeOutput(batch=filtered_batch, metrics=metrics)\n\n**Configuration:**\n\n.. code-block:: bash\n\n   python -m siirl.main_dag \\\n     algorithm.filter_threshold=0.5\n\n**Usage in Pipeline:**\n\n.. code-block:: python\n\n   pipeline.add_node(\n       \"reward_filter\",\n       func=\"my_module:reward_threshold_filter\",\n       deps=[\"function_reward\"],\n       node_type=NodeType.COMPUTE,\n       node_role=NodeRole.DYNAMIC_SAMPLING\n   )\n\nExample: Group Variance Filter\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. code-block:: python\n\n   from collections import defaultdict\n\n   def group_variance_filter(\n       config: SiiRLArguments,\n       batch,\n       **kwargs\n   ) -> NodeOutput:\n       \"\"\"Filter groups with low variance\"\"\"\n\n       rewards = batch.batch[\"rewards\"]\n       uids = batch.batch[\"uid\"]\n\n       # Group by uid\n       uid_to_rewards = defaultdict(list)\n       for i, uid in enumerate(uids):\n           uid_key = int(uid) if hasattr(uid, 'item') else uid\n           uid_to_rewards[uid_key].append(rewards[i].item())\n\n       # Calculate std for each group\n       min_std = config.algorithm.min_group_std\n       kept_uids = {\n           uid for uid, r in uid_to_rewards.items()\n           if torch.std(torch.tensor(r)).item() >= min_std\n       }\n\n       # Create mask\n       mask = torch.tensor([\n           (int(uids[i]) if hasattr(uids[i], 'item') else uids[i]) in kept_uids\n           for i in range(len(uids))\n       ], dtype=torch.bool)\n\n       # Apply filter\n       from siirl.data_coordinator.sample import filter_tensordict\n       filtered_batch = filter_tensordict(batch, mask)\n\n       metrics = {\n           \"filter/kept_groups\": len(kept_uids),\n           \"filter/total_groups\": len(uid_to_rewards)\n       }\n\n       return NodeOutput(batch=filtered_batch, metrics=metrics)\n"
  },
  {
    "path": "docs/user_interface/metrics_interface.rst",
    "content": "=================\nMetrics Interface\n=================\n\nCustom metrics allow you to track and aggregate any quantitative measures during training and validation. siiRL provides a distributed, Ray-based metrics system that automatically handles aggregation across all workers using various reduction operations (mean, max, min, sum).\n\nArchitecture Overview\n---------------------\n\n::\n\n                         Distributed Metrics Architecture\n   ==============================================================================\n\n   DAGWorker 0        DAGWorker 1        DAGWorker 2        DAGWorker N\n   +-----------+      +-----------+      +-----------+      +-----------+\n   | compute   |      | compute   |      | compute   |      | compute   |\n   | metrics   |      | metrics   |      | metrics   |      | metrics   |\n   +-----+-----+      +-----+-----+      +-----+-----+      +-----+-----+\n         |                  |                  |                  |\n         v                  v                  v                  v\n   +-----+-----+      +-----+-----+      +-----+-----+      +-----+-----+\n   | Metric    |      | Metric    |      | Metric    |      | Metric    |\n   | Client    |      | Client    |      | Client    |      | Client    |\n   +-----+-----+      +-----+-----+      +-----+-----+      +-----+-----+\n         |                  |                  |                  |\n         +------------------+------------------+------------------+\n                                    |\n                                    v\n                         +-------------------+\n                         |   MetricWorker    |  (Ray Actor - Singleton)\n                         |   (Aggregator)    |\n                         +-------------------+\n                         | - Collect metrics |\n                         | - Wait for all    |\n                         |   workers         |\n                         | - Aggregate:      |\n                         |   mean/max/min/   |\n                         |   sum             |\n                         +--------+----------+\n                                  |\n                                  v\n                         +-------------------+\n                         |  Final Metrics    |\n                         | (to Logger/WandB) |\n                         +-------------------+\n\n   ==============================================================================\n\n   Metrics Data Flow:\n\n   +-------------+     +----------------+     +----------------+     +--------+\n   | TensorDict  | --> | compute_*      | --> | MetricClient   | --> | Metric |\n   | (batch)     |     | _metric()      |     | .submit_metric |     | Worker |\n   +-------------+     +----------------+     +----------------+     +--------+\n                              |\n                              v\n                       +-------------+\n                       | Dict[str,   |\n                       |   float]    |\n                       | {name: val} |\n                       +-------------+\n\n   ==============================================================================\n\n**Key Files:**\n\n- ``siirl/execution/metric_worker/metric_worker.py`` - Ray-based distributed metrics aggregation\n- ``siirl/utils/metrics/metric_utils.py`` - Core metric computation functions\n- ``siirl/execution/metric_worker/utils.py`` - Aggregation function utilities\n\nQuick Start\n-----------\n\nMethod 1: Extending Core Metrics Functions 
(Recommended)\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n**Step 1:** Create your metric computation function in ``metric_utils.py``\n\n.. code-block:: python\n\n   # Add to siirl/utils/metrics/metric_utils.py\n\n   def compute_custom_data_metrics(data: TensorDict) -> Dict[str, float]:\n       \"\"\"Custom metrics computed from batch data\"\"\"\n       metrics = {}\n\n       # Token-level accuracy\n       if \"correct_tokens\" in data and \"attention_mask\" in data:\n           correct = data[\"correct_tokens\"].float()\n           mask = data[\"attention_mask\"].float()\n           accuracy = (correct * mask).sum() / mask.sum()\n           metrics[\"custom/token_accuracy/mean\"] = accuracy.item()\n\n       # Response quality score\n       if \"responses\" in data and \"response_mask\" in data:\n           response_quality = compute_response_quality_score(data)\n           metrics[\"custom/response_quality/mean\"] = response_quality.mean().item()\n           metrics[\"custom/response_quality/max\"] = response_quality.max().item()\n           metrics[\"custom/response_quality/min\"] = response_quality.min().item()\n\n       return metrics\n\n   def compute_response_quality_score(data: TensorDict) -> torch.Tensor:\n       \"\"\"Helper function to compute response quality\"\"\"\n       responses = data[\"responses\"]\n       response_mask = data[\"response_mask\"]\n\n       # Example: vocabulary diversity score\n       unique_tokens_per_response = []\n       for i in range(responses.shape[0]):\n           response_tokens = responses[i][response_mask[i].bool()]\n           unique_count = len(torch.unique(response_tokens))\n           unique_tokens_per_response.append(unique_count)\n\n       return torch.tensor(unique_tokens_per_response, device=responses.device).float()\n\n**Step 2:** Submit metrics using MetricClient\n\n.. code-block:: python\n\n   # Usage in your training loop\n   from siirl.execution.metric_worker.metric_worker import MetricClient\n\n   # In your DAG worker or training script\n   custom_metrics = compute_custom_data_metrics(batch)\n   metric_client.submit_metric(custom_metrics, world_size)\n\nCurrent Metrics System\n----------------------\n\nBuilt-in Metrics Reference\n~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nThe following tables list all built-in metrics provided by siiRL.\n\n**Data Metrics** (from ``compute_data_metric`` in ``metric_utils.py``):\n\n.. list-table:: Critic Metrics\n   :header-rows: 1\n   :widths: 40 60\n\n   * - Metric Name\n     - Description\n   * - ``critic/score/mean|max|min``\n     - Sequence-level scores from token-level scores\n   * - ``critic/rewards/mean|max|min``\n     - Sequence-level rewards from token-level rewards\n   * - ``critic/advantages/mean|max|min``\n     - Advantages (masked by response_mask)\n   * - ``critic/returns/mean|max|min``\n     - Returns (masked by response_mask)\n   * - ``critic/values/mean|max|min``\n     - Value function estimates (if available)\n   * - ``critic/vf_explained_var``\n     - Explained variance of value function\n\n.. list-table:: Response Analysis Metrics\n   :header-rows: 1\n   :widths: 40 60\n\n   * - Metric Name\n     - Description\n   * - ``response/length/mean|max|min``\n     - Response token lengths\n   * - ``response/clip_ratio/mean``\n     - Proportion hitting max response length\n   * - ``response/correct_length/mean|max|min``\n     - Lengths for responses with reward > 0.5\n   * - ``response/wrong_length/mean|max|min``\n     - Lengths for responses with reward ≤ 0.5\n\n.. 
list-table:: Prompt Analysis Metrics\n   :header-rows: 1\n   :widths: 40 60\n\n   * - Metric Name\n     - Description\n   * - ``prompt/length/mean|max|min``\n     - Prompt token lengths\n   * - ``prompt/clip_ratio/mean``\n     - Proportion hitting max prompt length\n\n.. list-table:: System & Multi-turn Metrics\n   :header-rows: 1\n   :widths: 40 60\n\n   * - Metric Name\n     - Description\n   * - ``perf/process_cpu_mem_used_gb``\n     - CPU memory usage per process\n   * - ``num_turns/min|max|mean``\n     - Statistics for multi-turn conversations\n\n**Timing Metrics** (from ``compute_timing_metrics``):\n\n.. list-table::\n   :header-rows: 1\n   :widths: 40 60\n\n   * - Metric Name\n     - Description\n   * - ``timing_s/{stage}``\n     - Raw timing in seconds for each stage\n   * - ``timing_per_token_ms/{stage}``\n     - Per-token timing in milliseconds\n\nStages: ``gen``, ``ref``, ``values``, ``adv``, ``update_critic``, ``update_actor``\n\n**Throughput Metrics** (from ``compute_throughout_metrics``):\n\n.. list-table::\n   :header-rows: 1\n   :widths: 40 60\n\n   * - Metric Name\n     - Description\n   * - ``perf/total_num_tokens``\n     - Total tokens processed\n   * - ``perf/time_per_step``\n     - Time per training step\n   * - ``perf/throughput``\n     - Tokens per second per GPU\n\n**Validation Metrics** (from ``process_validation_metrics``):\n\n.. list-table::\n   :header-rows: 1\n   :widths: 40 60\n\n   * - Metric Name\n     - Description\n   * - ``val-core/{data_source}/{var}/mean@N``\n     - Mean across N samples\n   * - ``val-core/{data_source}/{var}/best@N/mean|std``\n     - Bootstrap best-of-N statistics\n   * - ``val-core/{data_source}/{var}/worst@N/mean|std``\n     - Bootstrap worst-of-N statistics\n   * - ``val-core/{data_source}/{var}/maj@N/mean|std``\n     - Bootstrap majority voting statistics\n   * - ``val/test_score/{data_source}``\n     - Test score per data source\n\nCustom Metrics Implementation\n-----------------------------\n\nMethod 1: Custom Data Metrics\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nExtend the data metrics computed from training batches:\n\n.. 
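note::\n\n   A minimal sketch first; the field and metric names used here are illustrative assumptions, not part of siiRL's API.\n\n.. code-block:: python\n\n   def compute_extra_data_metrics(data: TensorDict) -> Dict[str, float]:\n       metrics = {}\n       if \"token_level_rewards\" in data:\n           seq_reward = data[\"token_level_rewards\"].sum(dim=-1)\n           # The name suffix (mean/max/min/sum) selects the cross-worker reduction.\n           metrics[\"custom/seq_reward/mean\"] = seq_reward.mean().item()\n           metrics[\"custom/seq_reward/max\"] = seq_reward.max().item()\n       return metrics\n\nA fuller example, covering policy entropy, gradient norms, and loss statistics:\n\n.. 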
code-block:: python\n\n   # Add to metric_utils.py\n   def compute_custom_training_metrics(data: TensorDict) -> Dict[str, float]:\n       \"\"\"Custom training-specific metrics\"\"\"\n       metrics = {}\n\n       # Policy entropy (exploration measure)\n       if \"policy_logits\" in data:\n           logits = data[\"policy_logits\"]\n           probs = torch.softmax(logits, dim=-1)\n           entropy = -torch.sum(probs * torch.log(probs + 1e-8), dim=-1)\n           response_mask = data.get(\"response_mask\", torch.ones_like(entropy))\n\n           # Only compute entropy for response tokens\n           masked_entropy = entropy * response_mask.float()\n           valid_entropy = masked_entropy.sum() / response_mask.sum()\n\n           metrics[\"training/policy_entropy/mean\"] = valid_entropy.item()\n\n       # Gradient norm tracking\n       if \"grad_norm\" in data:\n           metrics[\"training/grad_norm/mean\"] = data[\"grad_norm\"].item()\n\n       # Loss convergence tracking\n       if \"loss_values\" in data:\n           loss_values = data[\"loss_values\"]\n           metrics[\"training/loss/mean\"] = loss_values.mean().item()\n           metrics[\"training/loss/std\"] = loss_values.std().item()\n\n       return metrics\n\n   # Usage in MetricClient.compute_local_data_metric\n   def compute_local_data_metric(self, data: TensorDict, world_size: int):\n       # Standard metrics\n       standard_metrics = compute_data_metric(data)\n\n       # Add custom metrics\n       custom_metrics = compute_custom_training_metrics(data)\n\n       # Combine and submit\n       all_metrics = {**standard_metrics, **custom_metrics}\n       self.submit_metric(all_metrics, world_size)\n\nMethod 2: Custom Validation Metrics\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nAdd custom validation metrics with bootstrap sampling:\n\n.. 
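note::\n\n   As a standalone illustration of what the best@N bootstrap statistics measure (this sketch is not siiRL code), resample N scores with replacement and keep the best draw each time:\n\n.. code-block:: python\n\n   import numpy as np\n\n   def best_at_n(scores, n=4, n_bootstrap=1000, seed=0):\n       rng = np.random.default_rng(seed)\n       best = [np.max(rng.choice(scores, size=n, replace=True)) for _ in range(n_bootstrap)]\n       return float(np.mean(best)), float(np.std(best))\n\nInside siiRL the same idea is provided by ``bootstrap_metric``. A fuller example that groups samples by data source:\n\n.. 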
code-block:: python\n\n   # Add to metric_utils.py\n   def compute_custom_validation_metrics(\n       data_sources: list[str],\n       sample_inputs: list[str],\n       infos_dict: dict[str, list],\n       sample_turns: list[int]\n   ) -> dict[str, float]:\n       \"\"\"Custom validation metrics with bootstrap analysis\"\"\"\n\n       # Extract custom fields from infos_dict\n       custom_metrics = {}\n\n       if \"custom_score\" in infos_dict:\n           # Group by data source\n           source_scores = defaultdict(list)\n           for i, source in enumerate(data_sources):\n               source_scores[source].append(infos_dict[\"custom_score\"][i])\n\n           # Compute statistics per source\n           for source, scores in source_scores.items():\n               if len(scores) > 0:\n                   custom_metrics[f\"val/custom_score/{source}/mean\"] = np.mean(scores)\n                   custom_metrics[f\"val/custom_score/{source}/std\"] = np.std(scores)\n\n                   # Bootstrap sampling for confidence intervals\n                   if len(scores) > 1:\n                       bootstrap_results = bootstrap_metric(\n                           data=scores,\n                           subset_size=min(5, len(scores)),\n                           reduce_fns=[np.mean, np.max, np.min],\n                           n_bootstrap=1000\n                       )\n                       custom_metrics[f\"val/custom_score/{source}/bootstrap_mean\"] = bootstrap_results[0][0]\n                       custom_metrics[f\"val/custom_score/{source}/bootstrap_mean_std\"] = bootstrap_results[0][1]\n\n       # Conversation quality for multi-turn\n       if \"conversation_quality\" in infos_dict and len(sample_turns) > 0:\n           quality_by_turns = defaultdict(list)\n           for i, turns in enumerate(sample_turns):\n               if i < len(infos_dict[\"conversation_quality\"]):\n                   quality_by_turns[turns].append(infos_dict[\"conversation_quality\"][i])\n\n           for turn_count, qualities in quality_by_turns.items():\n               if len(qualities) > 0:\n                   custom_metrics[f\"val/conversation_quality/turns_{turn_count}/mean\"] = np.mean(qualities)\n\n       return custom_metrics\n\n   # Usage in MetricClient.process_local_validation_metrics\n   def process_local_validation_metrics(self, data_sources, sample_inputs, infos_dict, sample_turns, world_size):\n       # Standard validation metrics\n       standard_metrics = process_validation_metrics(data_sources, sample_inputs, infos_dict, sample_turns)\n\n       # Add custom validation metrics\n       custom_metrics = compute_custom_validation_metrics(data_sources, sample_inputs, infos_dict, sample_turns)\n\n       # Combine and submit\n       all_metrics = {**standard_metrics, **custom_metrics}\n       self.submit_metric(all_metrics, world_size)\n\nMethod 3: Custom Aggregation Logic\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nCreate custom aggregation functions for specialized reduction operations:\n\n.. 
code-block:: python\n\n   # Add to execution/metric_worker/utils.py\n   def MedianMetric(metrics: List[Metric]):\n       \"\"\"Custom median aggregation\"\"\"\n       values = [v for metric in metrics\n                for v in (metric.value if isinstance(metric.value, list) else [metric.value])]\n       return float(torch.median(torch.tensor(values)).item())\n\n   def PercentileMetric(percentile: float):\n       \"\"\"Custom percentile aggregation factory\"\"\"\n       def _percentile_metric(metrics: List[Metric]):\n           values = [v for metric in metrics\n                    for v in (metric.value if isinstance(metric.value, list) else [metric.value])]\n           return float(torch.quantile(torch.tensor(values), percentile / 100.0).item())\n       return _percentile_metric\n\n   # Update MetricFunc to handle custom aggregations\n   def MetricFunc(name: str):\n       if \"median\" in name:\n           return MedianMetric\n       elif \"p95\" in name:\n           return PercentileMetric(95)\n       elif \"p99\" in name:\n           return PercentileMetric(99)\n       elif \"min\" in name:\n           return MinMetric\n       elif \"max\" in name:\n           return MaxMetric\n       elif \"sum\" in name or \"total\" in name:\n           return SumMetric\n       else:\n           return MeanMetric\n\n   # Usage: name your metrics to trigger specific aggregations\n   metrics = {\n       \"custom/latency/median\": latency_values,  # Will use MedianMetric\n       \"custom/score/p95\": score_values,         # Will use 95th percentile\n   }\n\nMethod 4: Complex Custom Metrics\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nFor more sophisticated metrics requiring multiple computation steps:\n\n.. code-block:: python\n\n   # Add to metric_utils.py\n   def compute_advanced_metrics(data: TensorDict) -> Dict[str, float]:\n       \"\"\"Advanced metrics requiring complex computation\"\"\"\n       metrics = {}\n\n       # Sequence coherence analysis\n       if \"responses\" in data and \"attention_mask\" in data:\n           coherence_scores = compute_sequence_coherence(data)\n           metrics.update({\n               \"analysis/coherence/mean\": coherence_scores.mean().item(),\n               \"analysis/coherence/std\": coherence_scores.std().item(),\n               \"analysis/coherence/median\": coherence_scores.median().item(),\n           })\n\n       # Token transition analysis\n       if \"responses\" in data:\n           transition_metrics = analyze_token_transitions(data)\n           metrics.update(transition_metrics)\n\n       # Reward distribution analysis\n       if \"token_level_rewards\" in data:\n           reward_dist_metrics = analyze_reward_distribution(data)\n           metrics.update(reward_dist_metrics)\n\n       return metrics\n\n   def compute_sequence_coherence(data: TensorDict) -> torch.Tensor:\n       \"\"\"Compute coherence score for each sequence\"\"\"\n       responses = data[\"responses\"]\n       attention_mask = data[\"attention_mask\"]\n       batch_size = responses.shape[0]\n\n       coherence_scores = []\n       for i in range(batch_size):\n           # Extract valid tokens for this sequence\n           valid_length = attention_mask[i].sum().item()\n           sequence = responses[i][:valid_length]\n\n           # Compute local coherence (e.g., token transition smoothness)\n           if len(sequence) > 1:\n               # Simplified coherence: variance in token values\n               coherence = 1.0 / (1.0 + torch.var(sequence.float()).item())\n           else:\n        
       coherence = 1.0\n\n           coherence_scores.append(coherence)\n\n       return torch.tensor(coherence_scores, device=responses.device)\n\n   def analyze_token_transitions(data: TensorDict) -> Dict[str, float]:\n       \"\"\"Analyze patterns in token transitions\"\"\"\n       responses = data[\"responses\"]\n       response_mask = data.get(\"response_mask\", torch.ones_like(responses))\n\n       # Count unique transitions\n       unique_transitions = set()\n       total_transitions = 0\n\n       for i in range(responses.shape[0]):\n           response_tokens = responses[i][response_mask[i].bool()]\n           if len(response_tokens) > 1:\n               for j in range(len(response_tokens) - 1):\n                   transition = (response_tokens[j].item(), response_tokens[j+1].item())\n                   unique_transitions.add(transition)\n                   total_transitions += 1\n\n       diversity_ratio = len(unique_transitions) / max(total_transitions, 1)\n\n       return {\n           \"analysis/transition_diversity/mean\": diversity_ratio,\n           \"analysis/unique_transitions/total\": len(unique_transitions),\n           \"analysis/total_transitions/total\": total_transitions,\n       }\n\nIntegration with Training Workflow\n----------------------------------\n\nMetricClient Usage Pattern\n~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nThe ``MetricClient`` provides the main interface for submitting metrics:\n\n.. code-block:: python\n\n   from siirl.execution.metric_worker.metric_worker import MetricClient, MetricWorker\n\n   # Initialize metric worker and client\n   metric_worker = MetricWorker.remote()\n   await metric_worker.start.remote()\n   metric_client = MetricClient(metric_worker)\n\n   # During training loop\n   for step, batch in enumerate(dataloader):\n       # ... training logic ...\n\n       # Submit standard metrics\n       metric_client.compute_local_data_metric(batch, world_size)\n\n       # Submit custom metrics\n       custom_metrics = compute_advanced_metrics(batch)\n       metric_client.submit_metric(custom_metrics, world_size)\n\n       # Submit timing metrics\n       timing_data = {\"step\": step_time, \"forward\": forward_time}\n       metric_client.compute_local_timing_metrics(batch, timing_data, world_size)\n\n       # Wait for metrics to be processed\n       metric_client.wait_submit()\n\n   # Get final aggregated results\n   final_metrics = metric_client.wait_final_res()\n\nRay-based Distributed Aggregation\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nThe system uses Ray actors for distributed metrics processing:\n\n**MetricWorker Actor:**\n- Runs asynchronously to collect metrics from all workers\n- Aggregates metrics when all processes have submitted values\n- Supports different aggregation functions (mean, max, min, sum)\n- Automatically handles timing metric renaming (``timing_s/`` → ``perf/delta_time/``)\n\n**Aggregation Logic:**\n- Metrics are collected in a queue until all workers (``world_size``) submit\n- Each metric triggers computation when the expected number of submissions is reached\n- Final results are stored and returned when requested\n\nSpecial Metric Configurations\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nSome metrics require special aggregation logic:\n\n.. 
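note::\n\n   Before the special cases, the collect-then-reduce step described in the previous section can be pictured with the simplified sketch below. It is illustrative only and does not reproduce the actual ``MetricWorker`` implementation.\n\n.. code-block:: python\n\n   from collections import defaultdict\n\n   def aggregate_when_complete(submissions, world_size):\n       \"\"\"submissions: one {metric_name: value} dict per worker.\"\"\"\n       buckets = defaultdict(list)\n       for worker_metrics in submissions:\n           for name, value in worker_metrics.items():\n               buckets[name].append(value)\n       # A metric is reduced only once every worker has reported it; mean is the default.\n       return {\n           name: sum(values) / len(values)\n           for name, values in buckets.items()\n           if len(values) == world_size\n       }\n\nMetrics that must always use a fixed reduction, regardless of their name, are declared in ``Special_Metric``:\n\n.. 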
code-block:: python\n\n   # In metric_worker.py\n   Special_Metric = {\n       \"graph_output_handling\": MaxMetric,  # Only rollout_tp 0 contributes\n   }\n\nCustom metrics can be added to this dictionary for specialized handling.\n\nAdvanced Examples\n-----------------\n\nExample 1: Model Performance Analysis\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. code-block:: python\n\n   def compute_model_performance_metrics(data: TensorDict, model_outputs: dict) -> Dict[str, float]:\n       \"\"\"Comprehensive model performance analysis\"\"\"\n       metrics = {}\n\n       # Attention pattern analysis\n       if \"attention_weights\" in model_outputs:\n           attention_weights = model_outputs[\"attention_weights\"]\n\n           # Attention concentration (how focused is attention)\n           attention_entropy = -torch.sum(\n               attention_weights * torch.log(attention_weights + 1e-9), dim=-1\n           )\n           metrics[\"model/attention_entropy/mean\"] = attention_entropy.mean().item()\n\n           # Attention on different token types\n           if \"attention_mask\" in data:\n               prompt_attention = attention_weights[:, :, :-data[\"responses\"].shape[-1]]\n               response_attention = attention_weights[:, :, -data[\"responses\"].shape[-1]:]\n\n               metrics[\"model/prompt_attention_ratio/mean\"] = (\n                   prompt_attention.sum() / attention_weights.sum()\n               ).item()\n\n       # Hidden state analysis\n       if \"hidden_states\" in model_outputs:\n           hidden_states = model_outputs[\"hidden_states\"]\n\n           # Representation diversity\n           layer_norms = torch.norm(hidden_states, dim=-1)\n           metrics[\"model/hidden_norm/mean\"] = layer_norms.mean().item()\n           metrics[\"model/hidden_norm/std\"] = layer_norms.std().item()\n\n       return metrics\n\nExample 2: Conversation Quality Assessment\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. 
code-block:: python\n\n   def compute_conversation_quality_metrics(data: TensorDict) -> Dict[str, float]:\n       \"\"\"Multi-dimensional conversation quality assessment\"\"\"\n       metrics = {}\n\n       if \"responses\" not in data or \"prompts\" not in data:\n           return metrics\n\n       responses = data[\"responses\"]\n       prompts = data[\"prompts\"]\n       response_mask = data.get(\"response_mask\", torch.ones_like(responses))\n\n       batch_size = responses.shape[0]\n       quality_scores = []\n\n       for i in range(batch_size):\n           # Extract actual tokens (remove padding)\n           response_tokens = responses[i][response_mask[i].bool()]\n           prompt_tokens = prompts[i]\n\n           # Length appropriateness (not too short, not too long)\n           response_length = len(response_tokens)\n           length_score = compute_length_appropriateness(response_length)\n\n           # Vocabulary richness\n           unique_tokens = len(torch.unique(response_tokens))\n           vocab_score = min(unique_tokens / response_length, 1.0) if response_length > 0 else 0\n\n           # Repetition penalty\n           repetition_score = compute_repetition_score(response_tokens)\n\n           # Overall quality\n           quality = 0.3 * length_score + 0.3 * vocab_score + 0.4 * repetition_score\n           quality_scores.append(quality)\n\n       quality_tensor = torch.tensor(quality_scores, device=responses.device)\n\n       return {\n           \"conversation/quality/mean\": quality_tensor.mean().item(),\n           \"conversation/quality/std\": quality_tensor.std().item(),\n           \"conversation/quality/min\": quality_tensor.min().item(),\n           \"conversation/quality/max\": quality_tensor.max().item(),\n       }\n\n   def compute_length_appropriateness(length: int, target_length: int = 50) -> float:\n       \"\"\"Compute how appropriate the response length is\"\"\"\n       if length == 0:\n           return 0.0\n       ratio = length / target_length\n       if ratio <= 1.0:\n           return ratio  # Shorter is better than longer\n       else:\n           return 1.0 / ratio  # Penalize overly long responses\n\n   def compute_repetition_score(tokens: torch.Tensor) -> float:\n       \"\"\"Compute score based on repetition patterns\"\"\"\n       if len(tokens) <= 1:\n           return 1.0\n\n       # Count repeated bigrams\n       bigrams = set()\n       repeated_bigrams = 0\n\n       for i in range(len(tokens) - 1):\n           bigram = (tokens[i].item(), tokens[i+1].item())\n           if bigram in bigrams:\n               repeated_bigrams += 1\n           else:\n               bigrams.add(bigram)\n\n       # Higher repetition = lower score\n       repetition_ratio = repeated_bigrams / (len(tokens) - 1)\n       return 1.0 - repetition_ratio\n\nConfiguration and Best Practices\n---------------------------------\n\nMetric Naming Conventions\n~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nFollow these conventions for consistent metric organization:\n\n.. 
code-block:: text\n\n   # Training metrics\n   training/{category}/{metric_name}/{aggregation}\n\n   # Validation metrics\n   val/{category}/{data_source}/{metric_name}\n   val-core/{data_source}/{variable}/{metric_name}\n   val-aux/{category}/{metric_name}\n\n   # Performance metrics\n   perf/{metric_name}\n\n   # Analysis metrics\n   analysis/{category}/{metric_name}/{aggregation}\n\n   # Model introspection\n   model/{component}/{metric_name}/{aggregation}\n\nAggregation Selection\n~~~~~~~~~~~~~~~~~~~~~\n\nChoose aggregation methods based on metric semantics:\n\n- **mean**: Default for most metrics (accuracy, loss, etc.)\n- **max**: For peak values (max memory, worst-case latency)\n- **min**: For best-case scenarios (min loss, fastest response)\n- **sum/total**: For cumulative values (total tokens, total time)\n- **median**: For robust central tendency (when outliers matter)\n- **p95/p99**: For percentile-based SLA metrics\n\nError Handling\n~~~~~~~~~~~~~~\n\nAlways implement robust error handling:\n\n.. code-block:: python\n\n   def compute_safe_custom_metrics(data: TensorDict) -> Dict[str, float]:\n       \"\"\"Example of safe metric computation\"\"\"\n       metrics = {}\n\n       try:\n           # Check data availability\n           if \"required_field\" not in data:\n               return metrics\n\n           # Handle empty tensors\n           values = data[\"required_field\"]\n           if values.numel() == 0:\n               return metrics\n\n           # Compute metrics with numerical stability\n           mean_val = torch.mean(values.float())\n           if torch.isfinite(mean_val):\n               metrics[\"custom/metric/mean\"] = mean_val.item()\n\n       except Exception as e:\n           # Log error but don't crash training\n           print(f\"Error computing custom metrics: {e}\")\n           return {}\n\n       return metrics\n\nPerformance Considerations\n~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n- **Batch Processing**: Compute metrics on entire batches, not individual samples\n- **Device Placement**: Keep tensors on the same device as input data\n- **Memory Management**: Avoid accumulating large tensors across steps\n- **Async Processing**: Use Ray actors for non-blocking metrics aggregation\n- **Selective Computation**: Only compute expensive metrics when needed\n\nDebugging Custom Metrics\n~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. code-block:: python\n\n   import os\n\n   def debug_custom_metrics(data: TensorDict, metrics: Dict[str, float]):\n       \"\"\"Debug utility for custom metrics\"\"\"\n       if os.environ.get(\"DEBUG_METRICS\", \"0\") == \"1\":\n           print(f\"Data keys: {list(data.keys())}\")\n           print(f\"Data shapes: {[(k, v.shape if hasattr(v, 'shape') else type(v)) for k, v in data.items()]}\")\n           print(f\"Computed metrics: {metrics}\")\n\n           # Check for common issues\n           for name, value in metrics.items():\n               if not isinstance(value, (int, float)):\n                   print(f\"WARNING: Metric {name} has invalid type {type(value)}\")\n               elif not np.isfinite(value):\n                   print(f\"WARNING: Metric {name} is not finite: {value}\")\n\nFile Structure Summary\n----------------------\n\n.. 
code-block:: text\n\n   siirl/execution/metric_worker/\n   ├── metric_worker.py          # Ray actor for distributed aggregation\n   │   ├── MetricWorker          # Ray remote actor class\n   │   └── MetricClient          # Client interface\n   └── utils.py                  # Aggregation functions\n       ├── Metric                # Dataclass for metric values\n       ├── MetricFunc            # Function selection logic\n       ├── MeanMetric            # Mean aggregation\n       ├── MaxMetric             # Maximum aggregation\n       ├── MinMetric             # Minimum aggregation\n       └── SumMetric             # Sum aggregation\n\n   siirl/utils/metrics/\n   └── metric_utils.py           # Core metric computation\n       ├── compute_data_metric           # Standard training metrics\n       ├── compute_timing_metrics        # Timing analysis\n       ├── compute_throughout_metrics    # Throughput analysis\n       ├── process_validation_metrics    # Validation with bootstrap\n       ├── bootstrap_metric             # Bootstrap sampling utility\n       └── aggregate_validation_metrics  # Parallel validation processing\n\nThis architecture provides a scalable, flexible foundation for comprehensive metrics collection in distributed training environments."
  },
  {
    "path": "docs/user_interface/pipeline_interface.rst",
    "content": "============\nPipeline API\n============\n\nPipeline is a declarative Python API for defining training workflows. Each Pipeline consists of Nodes connected through dependencies to form a DAG.\n\nArchitecture Overview\n---------------------\n\n::\n\n                            Pipeline Architecture\n   ==============================================================================\n\n   +------------------+                      +------------------+\n   |    Pipeline      |     .build()         |   TaskGraph      |\n   |    (Builder)     | ------------------> |     (DAG)        |\n   +------------------+                      +------------------+\n   | - pipeline_id    |                      | - graph_id       |\n   | - description    |                      | - nodes: Dict    |\n   | - _nodes: Dict   |                      | - adj: Dict      |\n   +------------------+                      | - rev_adj: Dict  |\n                                             +------------------+\n                                                     |\n                                                     | executed by\n                                                     v\n                                             +------------------+\n                                             |   DAGWorker      |\n                                             |   (per GPU)      |\n                                             +------------------+\n\n   ==============================================================================\n\n   Built-in Pipelines Comparison:\n\n   +----------+------------------------------------------------------------------+\n   | Pipeline | Nodes Flow                                                       |\n   +----------+------------------------------------------------------------------+\n   | GRPO     | rollout -> reward -> advantage -> old_log -> ref_log -> train    |\n   +----------+------------------------------------------------------------------+\n   | PPO      | rollout -> reward -> value -> advantage -> old_log -> ref_log    |\n   |          |         -> train_actor -> train_critic                           |\n   +----------+------------------------------------------------------------------+\n   | DAPO     | rollout -> reward -> dynamic_sampling -> advantage -> old_log    |\n   |          |         -> ref_log -> train                                      |\n   +----------+------------------------------------------------------------------+\n   | Embodied | rollout -> embodied_sampling -> reward -> advantage -> old_log   |\n   | SRPO     |         -> ref_log -> train                                      |\n   +----------+------------------------------------------------------------------+\n\nBasic Usage\n-----------\n\nCreating a Pipeline\n~~~~~~~~~~~~~~~~~~~\n\n.. 
code-block:: python\n\n   from siirl.execution.dag.pipeline import Pipeline\n   from siirl.execution.dag.node import NodeType, NodeRole\n\n   pipeline = Pipeline(\"my_pipeline\", \"Description\")\n\n   # Add nodes (supports chaining)\n   pipeline.add_node(\n       \"node_id\",\n       func=\"module:function\",  # or \"module:Class.method\"\n       deps=[\"dependency_node_ids\"],\n       node_type=NodeType.COMPUTE,\n       node_role=NodeRole.DEFAULT\n   ).add_node(\n       \"next_node\",\n       func=\"module:another_function\",\n       deps=[\"node_id\"],\n       node_type=NodeType.MODEL_TRAIN,\n       node_role=NodeRole.ACTOR\n   )\n\n   # Build TaskGraph\n   task_graph = pipeline.build()\n\nNode Parameters\n~~~~~~~~~~~~~~~\n\n- ``node_id``: Unique identifier\n- ``func``: Function path (``\"module:function\"`` or ``\"module:Class.method\"``)\n- ``deps``: List of dependency node IDs\n- ``node_type``: MODEL_INFERENCE / MODEL_TRAIN / COMPUTE / DATA_LOAD\n- ``node_role``: ROLLOUT / ACTOR / CRITIC / REFERENCE / REWARD / ADVANTAGE / DYNAMIC_SAMPLING / DEFAULT\n- ``only_forward_compute``: Forward only (default False)\n\nBuilt-in Pipelines\n------------------\n\nsiiRL provides 4 built-in pipelines in ``siirl/execution/dag/builtin_pipelines.py``:\n\nGRPO Pipeline\n~~~~~~~~~~~~~\n\n**Workflow:** rollout → reward → advantage → old_log_prob → ref_log_prob → train_actor\n\n**Usage:**\n\n.. code-block:: bash\n\n   python -m siirl.main_dag \\\n     algorithm.adv_estimator=grpo\n\nPPO Pipeline\n~~~~~~~~~~~~\n\n**Workflow:** rollout → reward → critic_value → advantage → old_log_prob → ref_log_prob → train_actor → train_critic\n\n**Key Difference:** Adds value function and critic training\n\n**Usage:**\n\n.. code-block:: bash\n\n   python -m siirl.main_dag \\\n     algorithm.adv_estimator=gae \\\n     critic.enable=true\n\nDAPO Pipeline\n~~~~~~~~~~~~~\n\n**Workflow:** rollout → reward → dynamic_sampling → advantage → old_log_prob → ref_log_prob → train_actor\n\n**Key Feature:** Filters zero-variance sample groups\n\n**Usage:**\n\n.. code-block:: bash\n\n   python -m siirl.main_dag \\\n     algorithm.workflow_type=DAPO \\\n     algorithm.filter_groups.enable=true\n\nEmbodied GRPO Pipeline\n~~~~~~~~~~~~~~~~~~~~~~~\n\n**Workflow:** rollout → embodied_sampling → reward → advantage → old_log_prob → ref_log_prob → train_actor\n\n**Key Feature:** Embodied AI specific filtering\n\n**Usage:**\n\n.. code-block:: bash\n\n   python -m siirl.main_dag \\\n     algorithm.workflow_type=EMBODIED\n\nCustom Pipeline Definition\n---------------------------\n\nDefine Custom Pipeline\n~~~~~~~~~~~~~~~~~~~~~~\n\n.. code-block:: python\n\n   from siirl.execution.dag.pipeline import Pipeline\n   from siirl.execution.dag.task_graph import TaskGraph\n   from siirl.execution.dag.node import NodeType, NodeRole\n\n   def my_custom_pipeline() -> TaskGraph:\n       pipeline = Pipeline(\"my_pipeline\", \"My workflow\")\n\n       pipeline.add_node(\n           \"rollout_actor\",\n           func=\"siirl.dag_worker.dagworker:DAGWorker.generate\",\n           deps=[],\n           node_type=NodeType.MODEL_INFERENCE,\n           node_role=NodeRole.ROLLOUT\n       ).add_node(\n           \"my_custom_node\",\n           func=\"my_module:my_function\",\n           deps=[\"rollout_actor\"],\n           node_type=NodeType.COMPUTE,\n           node_role=NodeRole.DEFAULT\n       )\n\n       return pipeline.build()\n\nCustom Node Function\n~~~~~~~~~~~~~~~~~~~~\n\nNode functions must follow this signature:\n\n.. 
code-block:: python\n\n   from siirl.dag_worker.data_structures import NodeOutput\n\n   def my_function(batch, config=None, **kwargs) -> NodeOutput:\n       \"\"\"\n       Args:\n           batch: Input data (TensorDict)\n           config: Global configuration\n           **kwargs: Additional arguments\n\n       Returns:\n           NodeOutput(batch=processed_batch, metrics={})\n       \"\"\"\n       # Process batch\n       processed_batch = process(batch)\n\n       # Collect metrics\n       metrics = {\"metric_name\": value}\n\n       return NodeOutput(batch=processed_batch, metrics=metrics)\n\nUse Custom Pipeline\n~~~~~~~~~~~~~~~~~~~\n\n**Command Line:**\n\n.. code-block:: bash\n\n   python -m siirl.main_dag \\\n     dag.custom_pipeline_fn=\"my_module:my_custom_pipeline\"\n\n\n"
  },
  {
    "path": "docs/user_interface/reward_interface.rst",
    "content": "================\nReward Interface\n================\n\nCustom reward functions allow you to score model-generated responses. Simply write a Python function and specify its path in configuration.\n\n**Official Example:** ``siirl/user_interface/rewards_interface/custom_gsm8k_reward.py``\n\nArchitecture Overview\n---------------------\n\n::\n\n                           Reward Computation Flow\n   ==============================================================================\n\n   +------------------+     +-------------------+     +------------------+\n   |  Rollout Node    |     |   Reward Node     |     | Advantage Node   |\n   |  (Generation)    |---->|   (Scoring)       |---->|  (Normalization) |\n   +------------------+     +-------------------+     +------------------+\n                                    |\n                                    v\n                            +---------------+\n                            | RewardManager |\n                            +---------------+\n                                    |\n           +------------------------+------------------------+\n           |                        |                        |\n           v                        v                        v\n   +---------------+        +---------------+        +---------------+\n   | Naive Reward  |        | Batch Reward  |        | Custom Reward |\n   | (Rule-based)  |        | (Model-based) |        | (User-defined)|\n   +---------------+        +---------------+        +---------------+\n                                                            |\n                                                            v\n                                                    +---------------+\n                                                    | compute_score |\n                                                    | (data_source, |\n                                                    |  solution_str,|\n                                                    |  ground_truth,|\n                                                    |  extra_info)  |\n                                                    +-------+-------+\n                                                            |\n                                                            v\n                                                    +---------------+\n                                                    | Returns float |\n                                                    | score [0, 1]  |\n                                                    +---------------+\n\n   ==============================================================================\n\n   Custom Reward Function Integration:\n\n   Configuration                          Runtime\n   +---------------------------+          +---------------------------+\n   | custom_reward_function:   |          | RewardManager loads       |\n   |   path: /path/to/file.py  |  ----->  | compute_score function    |\n   |   name: compute_score     |          | and calls it per sample   |\n   +---------------------------+          +---------------------------+\n\nQuick Start\n-----------\n\nStep 1: Write Reward Function\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nCreate a Python file with ``compute_score`` function:\n\n.. 
code-block:: python\n\n   # my_reward.py\n\n   def compute_score(data_source, solution_str, ground_truth, extra_info):\n       \"\"\"\n       Custom reward function\n\n       Args:\n           data_source (str): Dataset source identifier (e.g., \"openai/gsm8k\")\n           solution_str (str): Model generated text\n           ground_truth (str): Correct answer\n           extra_info (dict): Additional information (optional)\n\n       Returns:\n           float: Score (typically 0-1)\n       \"\"\"\n       # Your scoring logic\n       if solution_str == ground_truth:\n           return 1.0\n       else:\n           return 0.0\n\nStep 2: Configuration\n~~~~~~~~~~~~~~~~~~~~~\n\n**Command Line:**\n\n.. code-block:: bash\n\n   python -m siirl.main_dag \\\n     custom_reward_function.path=/path/to/my_reward.py \\\n     custom_reward_function.name=compute_score\n\nOfficial Example: GSM8K\n-----------------------\n\n**File:** ``siirl/user_interface/rewards_interface/custom_gsm8k_reward.py``\n\n.. code-block:: python\n\n   import re\n\n   def extract_solution(solution_str, method=\"strict\"):\n       \"\"\"Extract answer from solution\"\"\"\n       if method == \"strict\":\n           # Requires #### <answer> format\n           solution = re.search(\"#### (\\\\-?[0-9\\\\.\\\\,]+)\", solution_str)\n           if solution is None:\n               return None\n           final_answer = solution.group(1).replace(\",\", \"\")\n           return final_answer\n       elif method == \"flexible\":\n           # Extract last number\n           answer = re.findall(\"(\\\\-?[0-9\\\\.\\\\,]+)\", solution_str)\n           if len(answer) == 0:\n               return None\n           for final_answer in reversed(answer):\n               if final_answer not in [\"\", \".\"]:\n                   return final_answer\n       return None\n\n   def compute_score(data_source, solution_str, ground_truth, extra_info):\n       \"\"\"\n       GSM8K scoring function\n\n       Checks format and compares answer\n       \"\"\"\n       method = \"strict\"\n       format_score = 0.0\n       score = 1.0\n\n       answer = extract_solution(solution_str, method=method)\n\n       if answer is None:\n           return 0  # Format error\n       elif answer == ground_truth:\n           return score  # Correct answer\n       else:\n           return format_score  # Correct format but wrong answer\n\n**Usage:**\n\n.. code-block:: bash\n\n   python -m siirl.main_dag \\\n     custom_reward_function.path=siirl/user_interface/rewards_interface/custom_gsm8k_reward.py \\\n     custom_reward_function.name=compute_score \\\n     data.train_files=/path/to/gsm8k.parquet\n\nCustom Examples\n---------------\n\nExample 1: Keyword Matching\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. code-block:: python\n\n   def compute_score(data_source, solution_str, ground_truth, extra_info):\n       \"\"\"Keyword-based reward\"\"\"\n       score = 0.0\n\n       # Check keywords\n       keywords = [\"because\", \"therefore\", \"thus\"]\n       for keyword in keywords:\n           if keyword in solution_str.lower():\n               score += 0.3\n\n       # Length check\n       words = len(solution_str.split())\n       if 50 <= words <= 200:\n           score += 0.4\n\n       return min(score, 1.0)\n\nExample 2: Regex Validation\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. 
code-block:: python\n\n   import re\n\n   def compute_score(data_source, solution_str, ground_truth, extra_info):\n       \"\"\"Regex-based format validation\"\"\"\n       # Extract numeric answer\n       match = re.search(r\"答案[是为][:：]\\s*(\\d+)\", solution_str)\n\n       if match is None:\n           return 0.0  # Incorrect format\n\n       answer = match.group(1)\n       if answer == ground_truth:\n           return 1.0  # Correct\n       else:\n           return 0.1  # Correct format but wrong answer\n\nExample 3: Multi-stage Scoring\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. code-block:: python\n\n   import re\n\n   def compute_score(data_source, solution_str, ground_truth, extra_info):\n       \"\"\"Multi-stage scoring: format + reasoning + correctness\"\"\"\n       score = 0.0\n\n       # Stage 1: Format check (0.2 points)\n       if \"####\" in solution_str:\n           score += 0.2\n\n       # Stage 2: Reasoning steps (0.3 points)\n       steps = solution_str.count('\\n')\n       if steps >= 3:\n           score += 0.3\n\n       # Stage 3: Answer correctness (0.5 points)\n       answer_match = re.search(r\"#### ([\\-0-9\\.]+)\", solution_str)\n       if answer_match:\n           answer = answer_match.group(1)\n           if answer == ground_truth:\n               score += 0.5\n\n       return score\n\nExample 4: Multiple Datasets\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. code-block:: python\n\n   def compute_score(data_source, solution_str, ground_truth, extra_info):\n       \"\"\"Route to different scoring functions based on data_source\"\"\"\n       if data_source == \"gsm8k\":\n           return score_gsm8k(solution_str, ground_truth)\n       elif data_source == \"math\":\n           return score_math(solution_str, ground_truth)\n       else:\n           return 0.0\n\n   def score_gsm8k(solution_str, ground_truth):\n       # GSM8K specific logic\n       pass\n\n   def score_math(solution_str, ground_truth):\n       # MATH specific logic\n       pass\n\nFunction Specification\n----------------------\n\nRequired Signature\n~~~~~~~~~~~~~~~~~~\n\n.. code-block:: python\n\n   def compute_score(data_source, solution_str, ground_truth, extra_info):\n       \"\"\"\n       Args:\n           data_source (str): Dataset source\n           solution_str (str): Model generated response\n           ground_truth (str): Correct answer\n           extra_info (dict): Additional information\n\n       Returns:\n           float: Score value\n       \"\"\"\n       pass\n\nImportant Notes\n~~~~~~~~~~~~~~~\n\n1. **Function Name:** Can be customized, specify via ``custom_reward_function.name``\n2. **Return Type:** Must return ``float``, typically in [0, 1] range\n3. **Error Handling:** Recommended to catch exceptions and return default value (e.g., 0.0)\n4. **Parameter Order:** Must follow the signature order\n"
  },
  {
    "path": "examples/cpgd_trainer/run_qwen2_5-7b.sh",
    "content": "#!/usr/bin/env bash\n# ===================================================================================\n# ===                       USER CONFIGURATION SECTION                            ===\n# ===================================================================================\n\n# --- Experiment and Model Definition ---\nexport DATASET=deepscaler\nexport ALG=cpgd\nexport MODEL_NAME=qwen2.5-7b\n\n# --- Path Definitions ---\nexport HOME={your_home_path}\nexport TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet\nexport TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet\nexport MODEL_PATH=$HOME/data/models/Qwen2.5-7B-Instruct\n\n# Base output paths\nexport BASE_CKPT_PATH=ckpts\nexport BASE_TENSORBOARD_PATH=tensorboard\n\n# --- Key Training Hyperparameters ---\nexport TRAIN_BATCH_SIZE_PER_NODE=512\nexport PPO_MINI_BATCH_SIZE_PER_NODE=256\nexport PPO_MICRO_BATCH_SIZE_PER_GPU=8\nexport MAX_PROMPT_LENGTH=2048\nexport MAX_RESPONSE_LENGTH=4096\nexport ROLLOUT_GPU_MEMORY_UTILIZATION=0.6\nexport ROLLOUT_TP=2\nexport ROLLOUT_N=8\nexport SAVE_FREQ=30\nexport TEST_FREQ=10\nexport TOTAL_EPOCHS=30\nexport MAX_CKPT_KEEP=5\n\n# --- Multi-node (Multi-machine) distributed training environments ---\n\n# Uncomment the following line and set the correct network interface if needed for distributed backend\n# export GLOO_SOCKET_IFNAME=bond0  # Modify as needed\n\n# --- Distributed Training & Infrastructure ---\nexport N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}\nexport NNODES=${PET_NNODES:-1}\nexport NODE_RANK=${PET_NODE_RANK:-0}\nexport MASTER_ADDR=${MASTER_ADDR:-localhost}\n\n# --- Output Paths and Experiment Naming ---\nexport CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes\nexport PROJECT_NAME=siirl_${DATASET}_${ALG}\nexport EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment\nexport TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp\nexport SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp\n\n# --- Calculated Global Hyperparameters ---\nexport TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))\nexport PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))\n\n# --- Define the Training Command and its Arguments ---\nTRAINING_CMD=(\n    python3 -m siirl.main_dag\n    algorithm.adv_estimator=\\$ALG\n    data.train_files=\\$TRAIN_DATA_PATH\n    data.val_files=\\$TEST_DATA_PATH\n    data.train_batch_size=\\$TRAIN_BATCH_SIZE\n    data.max_prompt_length=\\$MAX_PROMPT_LENGTH\n    data.max_response_length=\\$MAX_RESPONSE_LENGTH\n    data.filter_overlong_prompts=True\n    data.truncation='error'\n    data.shuffle=False\n    actor_rollout_ref.model.path=\\$MODEL_PATH\n    actor_rollout_ref.actor.optim.lr=1e-6\n    actor_rollout_ref.model.use_remove_padding=True\n    actor_rollout_ref.model.use_fused_kernels=False\n    actor_rollout_ref.actor.policy_drift_coeff=0.001\n    actor_rollout_ref.actor.policy_loss.loss_mode=cpgd\n    actor_rollout_ref.actor.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.actor.use_kl_loss=False\n    actor_rollout_ref.actor.grad_clip=0.5\n    actor_rollout_ref.actor.clip_ratio=0.2\n    actor_rollout_ref.actor.kl_loss_coef=0.01\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl\n    actor_rollout_ref.model.enable_gradient_checkpointing=True\n    
actor_rollout_ref.actor.fsdp_config.param_offload=False\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.rollout.tensor_model_parallel_size=\\$ROLLOUT_TP\n    actor_rollout_ref.rollout.name=vllm\n    actor_rollout_ref.rollout.gpu_memory_utilization=\\$ROLLOUT_GPU_MEMORY_UTILIZATION\n    actor_rollout_ref.rollout.max_model_len=8192\n    actor_rollout_ref.rollout.enable_chunked_prefill=False\n    actor_rollout_ref.rollout.enforce_eager=False\n    actor_rollout_ref.rollout.free_cache_engine=False\n    actor_rollout_ref.rollout.n=\\$ROLLOUT_N\n    actor_rollout_ref.rollout.prompt_length=\\$MAX_PROMPT_LENGTH  \n    actor_rollout_ref.rollout.response_length=\\$MAX_RESPONSE_LENGTH\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.ref.fsdp_config.param_offload=True\n    algorithm.weight_factor_in_cpgd='STD_weight'\n    algorithm.kl_ctrl.kl_coef=0.001\n    trainer.critic_warmup=0\n    trainer.logger=['console','tensorboard']\n    trainer.project_name=\\$PROJECT_NAME\n    trainer.experiment_name=\\$EXPERIMENT_NAME\n    trainer.n_gpus_per_node=\\$N_GPUS_PER_NODE\n    trainer.nnodes=\\$NNODES\n    trainer.save_freq=\\$SAVE_FREQ\n    trainer.test_freq=\\$TEST_FREQ\n    trainer.total_epochs=\\$TOTAL_EPOCHS\n    trainer.resume_mode=auto\n    trainer.max_actor_ckpt_to_keep=\\$MAX_CKPT_KEEP\n    trainer.default_local_dir=\\$CKPT_PATH\n    trainer.val_before_train=True\n)\n\n# ===================================================================================\n# ===                  MAIN EXECUTION LOGIC & INFRASTRUCTURE                      ===\n# ===================================================================================\n\n# --- Boilerplate Setup ---\nset -e\nset -o pipefail\nset -x\n\n# --- Infrastructure & Boilerplate Functions ---\nstart_ray_cluster() {\n    local RAY_HEAD_WAIT_TIMEOUT=600\n    export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200\n    export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120\n\n    local ray_start_common_opts=(\n        --num-gpus \"$N_GPUS_PER_NODE\"\n        --object-store-memory 100000000000\n        --memory 100000000000\n    )\n\n    if [ \"$NNODES\" -gt 1 ]; then\n        if [ \"$NODE_RANK\" = \"0\" ]; then\n            echo \"INFO: Starting Ray head node on $(hostname)...\"\n            export RAY_ADDRESS=\"$RAY_MASTER_ADDR:$RAY_MASTER_PORT\"\n            ray start --head --port=\"$RAY_MASTER_PORT\" --dashboard-port=\"$RAY_DASHBOARD_PORT\" \"${ray_start_common_opts[@]}\" --system-config='{\"gcs_server_request_timeout_seconds\": 60, \"gcs_rpc_server_reconnect_timeout_s\": 60}'\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$RAY_ADDRESS\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head node. Exiting.\" >&2; ray stop --force; exit 1; fi\n                echo \"Head node not healthy yet. 
Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head node is healthy.\"\n        else\n            local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n            echo \"INFO: Worker node $(hostname) waiting for head at $head_node_address...\"\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$head_node_address\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head. Exiting.\" >&2; exit 1; fi\n                echo \"Head not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head is healthy. Worker starting...\"\n            ray start --address=\"$head_node_address\" \"${ray_start_common_opts[@]}\"\n        fi\n    else\n        echo \"INFO: Starting Ray in single-node mode...\"\n        ray start --head \"${ray_start_common_opts[@]}\"\n    fi\n}\n\n# --- Main Execution Function ---\nmain() {\n    local timestamp=$(date +\"%Y%m%d_%H%M%S\")\n    ray stop --force\n\n\n\n    export VLLM_USE_V1=1\n    export GLOO_SOCKET_TIMEOUT=600\n    export GLOO_TCP_TIMEOUT=600\n    export GLOO_LOG_LEVEL=DEBUG\n    export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}\n    export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}\n    export RAY_MASTER_ADDR=$MASTER_ADDR\n\n    start_ray_cluster\n\n    if [ \"$NNODES\" -gt 1 ] && [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"Waiting for all $NNODES nodes to join...\"\n        local TIMEOUT=600; local start_time=$(date +%s)\n        while true; do\n            if [ \"$(( $(date +%s) - start_time ))\" -ge \"$TIMEOUT\" ]; then echo \"Error: Timeout waiting for nodes.\" >&2; exit 1; fi\n            local ready_nodes=$(ray list nodes --format=json | python3 -c \"import sys, json; print(len(json.load(sys.stdin)))\")\n            if [ \"$ready_nodes\" -ge \"$NNODES\" ]; then break; fi\n            echo \"Waiting... ($ready_nodes / $NNODES nodes ready)\"\n            sleep 5\n        done\n        echo \"All $NNODES nodes have joined.\"\n    fi\n\n    if [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"INFO [RANK 0]: Starting main training command.\"\n        eval \"${TRAINING_CMD[@]}\" \"$@\"\n        echo \"INFO [RANK 0]: Training finished.\"\n        sleep 30; ray stop --force >/dev/null 2>&1\n    elif [ \"$NNODES\" -gt 1 ]; then\n        local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n        echo \"INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address.\"\n        while ray health-check --address \"$head_node_address\" &>/dev/null; do sleep 15; done\n        echo \"INFO [RANK $NODE_RANK]: Head node down. Exiting.\"\n    fi\n\n    echo \"INFO: Script finished on rank $NODE_RANK.\"\n}\n\n# --- Script Entrypoint ---\nmain \"$@\"\n"
  },
  {
    "path": "examples/cpgd_trainer/run_qwen2_5_vl-72b.sh",
    "content": "#!/usr/bin/env bash\n# ===================================================================================\n# ===                       USER CONFIGURATION SECTION                            ===\n# ===================================================================================\n\n# --- Experiment and Model Definition ---\nexport DATASET=mm_eureka\nexport ALG=cpgd\nexport MODEL_NAME=qwen2.5-vl-72b\n\n# --- Path Definitions ---\nexport HOME={your_home_path}\nexport TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet\nexport TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet\nexport MODEL_PATH=$HOME/data/models/Qwen2.5-VL-72B-Instruct\n\n# Base output paths\nexport BASE_CKPT_PATH=ckpts\nexport BASE_TENSORBOARD_PATH=tensorboard\n\n# --- Key Training Hyperparameters ---\nexport TRAIN_BATCH_SIZE_PER_NODE=512\nexport PPO_MINI_BATCH_SIZE_PER_NODE=128\nexport PPO_MICRO_BATCH_SIZE_PER_GPU=8\nexport MAX_PROMPT_LENGTH=2048\nexport MAX_RESPONSE_LENGTH=4096\nexport ROLLOUT_GPU_MEMORY_UTILIZATION=0.6\nexport ROLLOUT_TP=8\nexport ROLLOUT_N=8\nexport SAVE_FREQ=30\nexport TEST_FREQ=10\nexport TOTAL_EPOCHS=30\nexport MAX_CKPT_KEEP=5\n\n# --- Multi-node (Multi-machine) distributed training environments ---\n\n# Uncomment the following line and set the correct network interface if needed for distributed backend\n# export GLOO_SOCKET_IFNAME=bond0  # Modify as needed\n\n# --- Distributed Training & Infrastructure ---\nexport N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}\nexport NNODES=${PET_NNODES:-1}\nexport NODE_RANK=${PET_NODE_RANK:-0}\nexport MASTER_ADDR=${MASTER_ADDR:-localhost}\n\n# --- Output Paths and Experiment Naming ---\nexport CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes\nexport PROJECT_NAME=siirl_${DATASET}_${ALG}\nexport EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment\nexport TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp\nexport SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp\n\n# --- Calculated Global Hyperparameters ---\nexport TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))\nexport PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))\n\n# --- Define the Training Command and its Arguments ---\nTRAINING_CMD=(\n    python3 -m siirl.main_dag\n    algorithm.adv_estimator=\\$ALG\n    data.train_files=\\$TRAIN_DATA_PATH\n    data.val_files=\\$TEST_DATA_PATH\n    data.train_batch_size=\\$TRAIN_BATCH_SIZE\n    data.max_prompt_length=\\$MAX_PROMPT_LENGTH\n    data.max_response_length=\\$MAX_RESPONSE_LENGTH\n    data.filter_overlong_prompts=True\n    data.truncation='error'\n    data.shuffle=False\n    actor_rollout_ref.model.path=\\$MODEL_PATH\n    actor_rollout_ref.actor.optim.lr=1e-6\n    actor_rollout_ref.model.use_remove_padding=True\n    actor_rollout_ref.model.use_fused_kernels=False\n    actor_rollout_ref.actor.policy_drift_coeff=0.001\n    actor_rollout_ref.actor.policy_loss.loss_mode=cpgd\n    actor_rollout_ref.actor.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.actor.use_kl_loss=True\n    actor_rollout_ref.actor.grad_clip=0.5\n    actor_rollout_ref.actor.clip_ratio=0.2\n    actor_rollout_ref.actor.kl_loss_coef=0.01\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl\n    actor_rollout_ref.model.enable_gradient_checkpointing=True\n    
actor_rollout_ref.actor.fsdp_config.param_offload=False\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1\n    actor_rollout_ref.rollout.tensor_model_parallel_size=\\$ROLLOUT_TP\n    actor_rollout_ref.rollout.name=vllm\n    actor_rollout_ref.rollout.gpu_memory_utilization=\\$ROLLOUT_GPU_MEMORY_UTILIZATION\n    actor_rollout_ref.rollout.max_model_len=8192\n    actor_rollout_ref.rollout.enable_chunked_prefill=False\n    actor_rollout_ref.rollout.enforce_eager=False\n    actor_rollout_ref.rollout.free_cache_engine=False\n    actor_rollout_ref.rollout.n=\\$ROLLOUT_N\n    actor_rollout_ref.rollout.prompt_length=\\$MAX_PROMPT_LENGTH  \n    actor_rollout_ref.rollout.response_length=\\$MAX_RESPONSE_LENGTH\n    actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1\n    actor_rollout_ref.ref.fsdp_config.param_offload=True\n    algorithm.weight_factor_in_cpgd='STD_weight'\n    algorithm.kl_ctrl.kl_coef=0.001\n    algorithm.use_kl_in_reward=False\n    trainer.critic_warmup=0\n    trainer.logger=['console','tensorboard']\n    trainer.project_name=\\$PROJECT_NAME\n    trainer.experiment_name=\\$EXPERIMENT_NAME\n    trainer.n_gpus_per_node=\\$N_GPUS_PER_NODE\n    trainer.nnodes=\\$NNODES\n    trainer.save_freq=\\$SAVE_FREQ\n    trainer.test_freq=\\$TEST_FREQ\n    trainer.total_epochs=\\$TOTAL_EPOCHS\n    trainer.resume_mode=auto\n    trainer.max_actor_ckpt_to_keep=\\$MAX_CKPT_KEEP\n    trainer.del_local_ckpt_after_load=False\n    trainer.default_local_dir=\\$CKPT_PATH\n    trainer.val_before_train=True\n)\n\n# ===================================================================================\n# ===                  MAIN EXECUTION LOGIC & INFRASTRUCTURE                      ===\n# ===================================================================================\n\n# --- Boilerplate Setup ---\nset -e\nset -o pipefail\nset -x\n\n# --- Infrastructure & Boilerplate Functions ---\nstart_ray_cluster() {\n    local RAY_HEAD_WAIT_TIMEOUT=600\n    export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200\n    export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120\n\n    local ray_start_common_opts=(\n        --num-gpus \"$N_GPUS_PER_NODE\"\n        --object-store-memory 100000000000\n        --memory 100000000000\n    )\n\n    if [ \"$NNODES\" -gt 1 ]; then\n        if [ \"$NODE_RANK\" = \"0\" ]; then\n            echo \"INFO: Starting Ray head node on $(hostname)...\"\n            export RAY_ADDRESS=\"$RAY_MASTER_ADDR:$RAY_MASTER_PORT\"\n            ray start --head --port=\"$RAY_MASTER_PORT\" --dashboard-port=\"$RAY_DASHBOARD_PORT\" \"${ray_start_common_opts[@]}\" --system-config='{\"gcs_server_request_timeout_seconds\": 60, \"gcs_rpc_server_reconnect_timeout_s\": 60}'\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$RAY_ADDRESS\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head node. Exiting.\" >&2; ray stop --force; exit 1; fi\n                echo \"Head node not healthy yet. 
Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head node is healthy.\"\n        else\n            local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n            echo \"INFO: Worker node $(hostname) waiting for head at $head_node_address...\"\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$head_node_address\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head. Exiting.\" >&2; exit 1; fi\n                echo \"Head not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head is healthy. Worker starting...\"\n            ray start --address=\"$head_node_address\" \"${ray_start_common_opts[@]}\"\n        fi\n    else\n        echo \"INFO: Starting Ray in single-node mode...\"\n        ray start --head \"${ray_start_common_opts[@]}\"\n    fi\n}\n\n# --- Main Execution Function ---\nmain() {\n    local timestamp=$(date +\"%Y%m%d_%H%M%S\")\n    ray stop --force\n\n\n\n    export VLLM_USE_V1=1\n    export GLOO_SOCKET_IFNAME=bond0\n    export GLOO_SOCKET_TIMEOUT=600\n    export GLOO_TCP_TIMEOUT=600\n    export GLOO_LOG_LEVEL=DEBUG\n    export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}\n    export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}\n    export RAY_MASTER_ADDR=$MASTER_ADDR\n\n    start_ray_cluster\n\n    if [ \"$NNODES\" -gt 1 ] && [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"Waiting for all $NNODES nodes to join...\"\n        local TIMEOUT=600; local start_time=$(date +%s)\n        while true; do\n            if [ \"$(( $(date +%s) - start_time ))\" -ge \"$TIMEOUT\" ]; then echo \"Error: Timeout waiting for nodes.\" >&2; exit 1; fi\n            local ready_nodes=$(ray list nodes --format=json | python3 -c \"import sys, json; print(len(json.load(sys.stdin)))\")\n            if [ \"$ready_nodes\" -ge \"$NNODES\" ]; then break; fi\n            echo \"Waiting... ($ready_nodes / $NNODES nodes ready)\"\n            sleep 5\n        done\n        echo \"All $NNODES nodes have joined.\"\n    fi\n\n    if [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"INFO [RANK 0]: Starting main training command.\"\n        eval \"${TRAINING_CMD[@]}\" \"$@\"\n        echo \"INFO [RANK 0]: Training finished.\"\n        sleep 30; ray stop --force >/dev/null 2>&1\n    elif [ \"$NNODES\" -gt 1 ]; then\n        local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n        echo \"INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address.\"\n        while ray health-check --address \"$head_node_address\" &>/dev/null; do sleep 15; done\n        echo \"INFO [RANK $NODE_RANK]: Head node down. Exiting.\"\n    fi\n\n    echo \"INFO: Script finished on rank $NODE_RANK.\"\n}\n\n# --- Script Entrypoint ---\nmain \"$@\"\n"
  },
  {
    "path": "examples/cpgd_trainer/run_qwen2_5_vl-7b.sh",
    "content": "#!/usr/bin/env bash\n# ===================================================================================\n# ===                       USER CONFIGURATION SECTION                            ===\n# ===================================================================================\n\n# --- Experiment and Model Definition ---\nexport DATASET=mm_eureka\nexport ALG=cpgd\nexport MODEL_NAME=qwen2.5-vl-7b\n\n# --- Path Definitions ---\nexport HOME={your_home_path}\nexport TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet\nexport TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet\nexport MODEL_PATH=$HOME/data/models/Qwen2.5-VL-7B-Instruct\n\n# Base output paths\nexport BASE_CKPT_PATH=ckpts\nexport BASE_TENSORBOARD_PATH=tensorboard\n\n# --- Key Training Hyperparameters ---\nexport TRAIN_BATCH_SIZE_PER_NODE=512\nexport PPO_MINI_BATCH_SIZE_PER_NODE=256\nexport PPO_MICRO_BATCH_SIZE_PER_GPU=8\nexport MAX_PROMPT_LENGTH=2048\nexport MAX_RESPONSE_LENGTH=4096\nexport ROLLOUT_GPU_MEMORY_UTILIZATION=0.6\nexport ROLLOUT_TP=2\nexport ROLLOUT_N=8\nexport SAVE_FREQ=30\nexport TEST_FREQ=10\nexport TOTAL_EPOCHS=30\nexport MAX_CKPT_KEEP=5\n\n# --- Multi-node (Multi-machine) distributed training environments ---\n\n# Uncomment the following line and set the correct network interface if needed for distributed backend\n# export GLOO_SOCKET_IFNAME=bond0  # Modify as needed\n\n# --- Distributed Training & Infrastructure ---\nexport N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}\nexport NNODES=${PET_NNODES:-1}\nexport NODE_RANK=${PET_NODE_RANK:-0}\nexport MASTER_ADDR=${MASTER_ADDR:-localhost}\n\n# --- Output Paths and Experiment Naming ---\nexport CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes\nexport PROJECT_NAME=siirl_${DATASET}_${ALG}\nexport EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment\nexport TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp\nexport SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp\n\n# --- Calculated Global Hyperparameters ---\nexport TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))\nexport PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))\n\n# --- Define the Training Command and its Arguments ---\nTRAINING_CMD=(\n    python3 -m siirl.main_dag\n    algorithm.adv_estimator=\\$ALG\n    data.train_files=\\$TRAIN_DATA_PATH\n    data.val_files=\\$TEST_DATA_PATH\n    data.train_batch_size=\\$TRAIN_BATCH_SIZE\n    data.max_prompt_length=\\$MAX_PROMPT_LENGTH\n    data.max_response_length=\\$MAX_RESPONSE_LENGTH\n    data.filter_overlong_prompts=True\n    data.truncation='error'\n    data.shuffle=False\n    actor_rollout_ref.model.path=\\$MODEL_PATH\n    actor_rollout_ref.actor.optim.lr=1e-6\n    actor_rollout_ref.model.use_remove_padding=True\n    actor_rollout_ref.model.use_fused_kernels=False\n    actor_rollout_ref.actor.policy_drift_coeff=0.001\n    actor_rollout_ref.actor.policy_loss.loss_mode=cpgd\n    actor_rollout_ref.actor.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.actor.use_kl_loss=False\n    actor_rollout_ref.actor.grad_clip=0.5\n    actor_rollout_ref.actor.clip_ratio=0.2\n    actor_rollout_ref.actor.kl_loss_coef=0.01\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl\n    actor_rollout_ref.model.enable_gradient_checkpointing=True\n    
actor_rollout_ref.actor.fsdp_config.param_offload=False\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.rollout.tensor_model_parallel_size=\\$ROLLOUT_TP\n    actor_rollout_ref.rollout.name=vllm\n    actor_rollout_ref.rollout.gpu_memory_utilization=\\$ROLLOUT_GPU_MEMORY_UTILIZATION\n    actor_rollout_ref.rollout.max_model_len=8192\n    actor_rollout_ref.rollout.enable_chunked_prefill=False\n    actor_rollout_ref.rollout.enforce_eager=False\n    actor_rollout_ref.rollout.free_cache_engine=False\n    actor_rollout_ref.rollout.n=\\$ROLLOUT_N\n    actor_rollout_ref.rollout.prompt_length=\\$MAX_PROMPT_LENGTH  \n    actor_rollout_ref.rollout.response_length=\\$MAX_RESPONSE_LENGTH\n    actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.ref.fsdp_config.param_offload=True\n    algorithm.weight_factor_in_cpgd='STD_weight'\n    algorithm.kl_ctrl.kl_coef=0.001\n    trainer.critic_warmup=0\n    trainer.logger=['console','tensorboard']\n    trainer.project_name=\\$PROJECT_NAME\n    trainer.experiment_name=\\$EXPERIMENT_NAME\n    trainer.n_gpus_per_node=\\$N_GPUS_PER_NODE\n    trainer.nnodes=\\$NNODES\n    trainer.save_freq=\\$SAVE_FREQ\n    trainer.test_freq=\\$TEST_FREQ\n    trainer.total_epochs=\\$TOTAL_EPOCHS\n    trainer.resume_mode=auto\n    trainer.max_actor_ckpt_to_keep=\\$MAX_CKPT_KEEP\n    trainer.default_local_dir=\\$CKPT_PATH\n    trainer.val_before_train=True\n)\n\n# ===================================================================================\n# ===                  MAIN EXECUTION LOGIC & INFRASTRUCTURE                      ===\n# ===================================================================================\n\n# --- Boilerplate Setup ---\nset -e\nset -o pipefail\nset -x\n\n# --- Infrastructure & Boilerplate Functions ---\nstart_ray_cluster() {\n    local RAY_HEAD_WAIT_TIMEOUT=600\n    export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200\n    export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120\n\n    local ray_start_common_opts=(\n        --num-gpus \"$N_GPUS_PER_NODE\"\n        --object-store-memory 100000000000\n        --memory 100000000000\n    )\n\n    if [ \"$NNODES\" -gt 1 ]; then\n        if [ \"$NODE_RANK\" = \"0\" ]; then\n            echo \"INFO: Starting Ray head node on $(hostname)...\"\n            export RAY_ADDRESS=\"$RAY_MASTER_ADDR:$RAY_MASTER_PORT\"\n            ray start --head --port=\"$RAY_MASTER_PORT\" --dashboard-port=\"$RAY_DASHBOARD_PORT\" \"${ray_start_common_opts[@]}\" --system-config='{\"gcs_server_request_timeout_seconds\": 60, \"gcs_rpc_server_reconnect_timeout_s\": 60}'\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$RAY_ADDRESS\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head node. Exiting.\" >&2; ray stop --force; exit 1; fi\n                echo \"Head node not healthy yet. 
Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head node is healthy.\"\n        else\n            local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n            echo \"INFO: Worker node $(hostname) waiting for head at $head_node_address...\"\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$head_node_address\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head. Exiting.\" >&2; exit 1; fi\n                echo \"Head not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head is healthy. Worker starting...\"\n            ray start --address=\"$head_node_address\" \"${ray_start_common_opts[@]}\"\n        fi\n    else\n        echo \"INFO: Starting Ray in single-node mode...\"\n        ray start --head \"${ray_start_common_opts[@]}\"\n    fi\n}\n\n# --- Main Execution Function ---\nmain() {\n    local timestamp=$(date +\"%Y%m%d_%H%M%S\")\n    ray stop --force\n\n\n\n    export VLLM_USE_V1=1\n    export GLOO_SOCKET_TIMEOUT=600\n    export GLOO_TCP_TIMEOUT=600\n    export GLOO_LOG_LEVEL=DEBUG\n    export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}\n    export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}\n    export RAY_MASTER_ADDR=$MASTER_ADDR\n\n    start_ray_cluster\n\n    if [ \"$NNODES\" -gt 1 ] && [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"Waiting for all $NNODES nodes to join...\"\n        local TIMEOUT=600; local start_time=$(date +%s)\n        while true; do\n            if [ \"$(( $(date +%s) - start_time ))\" -ge \"$TIMEOUT\" ]; then echo \"Error: Timeout waiting for nodes.\" >&2; exit 1; fi\n            local ready_nodes=$(ray list nodes --format=json | python3 -c \"import sys, json; print(len(json.load(sys.stdin)))\")\n            if [ \"$ready_nodes\" -ge \"$NNODES\" ]; then break; fi\n            echo \"Waiting... ($ready_nodes / $NNODES nodes ready)\"\n            sleep 5\n        done\n        echo \"All $NNODES nodes have joined.\"\n    fi\n\n    if [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"INFO [RANK 0]: Starting main training command.\"\n        eval \"${TRAINING_CMD[@]}\" \"$@\"\n        echo \"INFO [RANK 0]: Training finished.\"\n        sleep 30; ray stop --force >/dev/null 2>&1\n    elif [ \"$NNODES\" -gt 1 ]; then\n        local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n        echo \"INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address.\"\n        while ray health-check --address \"$head_node_address\" &>/dev/null; do sleep 15; done\n        echo \"INFO [RANK $NODE_RANK]: Head node down. Exiting.\"\n    fi\n\n    echo \"INFO: Script finished on rank $NODE_RANK.\"\n}\n\n# --- Script Entrypoint ---\nmain \"$@\"\n"
  },
  {
    "path": "examples/cpgd_trainer/run_qwen3-1.7b.sh",
    "content": "#!/usr/bin/env bash\n# ===================================================================================\n# ===                       USER CONFIGURATION SECTION                            ===\n# ===================================================================================\n\n# --- Experiment and Model Definition ---\nexport DATASET=deepscaler\nexport ALG=cpgd\nexport MODEL_NAME=qwen3-1.7b\n\n# --- Path Definitions ---\nexport HOME={your_home_path}\nexport TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet\nexport TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet\nexport MODEL_PATH=$HOME/data/models/Qwen3-1.7B\n\n# Base output paths\nexport BASE_CKPT_PATH=ckpts\nexport BASE_TENSORBOARD_PATH=tensorboard\n\n# --- Key Training Hyperparameters ---\nexport TRAIN_BATCH_SIZE_PER_NODE=512\nexport PPO_MINI_BATCH_SIZE_PER_NODE=256\nexport PPO_MICRO_BATCH_SIZE_PER_GPU=8\nexport MAX_PROMPT_LENGTH=1024\nexport MAX_RESPONSE_LENGTH=2048\nexport ROLLOUT_GPU_MEMORY_UTILIZATION=0.5\nexport ROLLOUT_TP=1\nexport ROLLOUT_N=8\nexport SAVE_FREQ=30\nexport TEST_FREQ=10\nexport TOTAL_EPOCHS=30\nexport MAX_CKPT_KEEP=5\n\n# --- Multi-node (Multi-machine) distributed training environments ---\n\n# Uncomment the following line and set the correct network interface if needed for distributed backend\n# export GLOO_SOCKET_IFNAME=bond0  # Modify as needed\n\n# --- Distributed Training & Infrastructure ---\nexport N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}\nexport NNODES=${PET_NNODES:-1}\nexport NODE_RANK=${PET_NODE_RANK:-0}\nexport MASTER_ADDR=${MASTER_ADDR:-localhost}\n\n# --- Output Paths and Experiment Naming ---\nexport CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes\nexport PROJECT_NAME=siirl_${DATASET}_${ALG}\nexport EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment\nexport TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp\nexport SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp\n\n# --- Calculated Global Hyperparameters ---\nexport TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))\nexport PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))\n\n# --- Define the Training Command and its Arguments ---\nTRAINING_CMD=(\n    python3 -m siirl.main_dag\n    algorithm.adv_estimator=\\$ALG\n    data.train_files=\\$TRAIN_DATA_PATH\n    data.val_files=\\$TEST_DATA_PATH\n    data.train_batch_size=\\$TRAIN_BATCH_SIZE\n    data.max_prompt_length=\\$MAX_PROMPT_LENGTH\n    data.max_response_length=\\$MAX_RESPONSE_LENGTH\n    data.filter_overlong_prompts=True\n    data.truncation='error'\n    data.shuffle=False\n    actor_rollout_ref.model.path=\\$MODEL_PATH\n    actor_rollout_ref.actor.optim.lr=1e-6\n    actor_rollout_ref.model.use_remove_padding=True\n    actor_rollout_ref.model.use_fused_kernels=False\n    actor_rollout_ref.actor.policy_drift_coeff=0.001\n    actor_rollout_ref.actor.policy_loss.loss_mode=cpgd\n    actor_rollout_ref.actor.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.actor.use_kl_loss=True\n    actor_rollout_ref.actor.grad_clip=0.5\n    actor_rollout_ref.actor.clip_ratio=0.2\n    actor_rollout_ref.actor.kl_loss_coef=0.01\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl\n    actor_rollout_ref.model.enable_gradient_checkpointing=True\n    actor_rollout_ref.actor.fsdp_config.param_offload=False\n    
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.rollout.tensor_model_parallel_size=\\$ROLLOUT_TP\n    actor_rollout_ref.rollout.name=vllm\n    actor_rollout_ref.rollout.gpu_memory_utilization=\\$ROLLOUT_GPU_MEMORY_UTILIZATION\n    actor_rollout_ref.rollout.max_model_len=8192\n    actor_rollout_ref.rollout.enable_chunked_prefill=False\n    actor_rollout_ref.rollout.enforce_eager=False\n    actor_rollout_ref.rollout.free_cache_engine=False\n    actor_rollout_ref.rollout.n=\\$ROLLOUT_N\n    actor_rollout_ref.rollout.prompt_length=\\$MAX_PROMPT_LENGTH  \n    actor_rollout_ref.rollout.response_length=\\$MAX_RESPONSE_LENGTH\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.ref.fsdp_config.param_offload=True\n    algorithm.weight_factor_in_cpgd='STD_weight'\n    algorithm.kl_ctrl.kl_coef=0.001\n    trainer.critic_warmup=0\n    trainer.logger=['console','tensorboard']\n    trainer.project_name=\\$PROJECT_NAME\n    trainer.experiment_name=\\$EXPERIMENT_NAME\n    trainer.n_gpus_per_node=\\$N_GPUS_PER_NODE\n    trainer.nnodes=\\$NNODES\n    trainer.save_freq=\\$SAVE_FREQ\n    trainer.test_freq=\\$TEST_FREQ\n    trainer.total_epochs=\\$TOTAL_EPOCHS\n    trainer.resume_mode=auto\n    trainer.max_actor_ckpt_to_keep=\\$MAX_CKPT_KEEP\n    trainer.default_local_dir=\\$CKPT_PATH\n    trainer.val_before_train=True\n)\n\n# ===================================================================================\n# ===                  MAIN EXECUTION LOGIC & INFRASTRUCTURE                      ===\n# ===================================================================================\n\n# --- Boilerplate Setup ---\nset -e\nset -o pipefail\nset -x\n\n# --- Infrastructure & Boilerplate Functions ---\nstart_ray_cluster() {\n    local RAY_HEAD_WAIT_TIMEOUT=600\n    export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200\n    export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120\n\n    local ray_start_common_opts=(\n        --num-gpus \"$N_GPUS_PER_NODE\"\n        --object-store-memory 100000000000\n        --memory 100000000000\n    )\n\n    if [ \"$NNODES\" -gt 1 ]; then\n        if [ \"$NODE_RANK\" = \"0\" ]; then\n            echo \"INFO: Starting Ray head node on $(hostname)...\"\n            export RAY_ADDRESS=\"$RAY_MASTER_ADDR:$RAY_MASTER_PORT\"\n            ray start --head --port=\"$RAY_MASTER_PORT\" --dashboard-port=\"$RAY_DASHBOARD_PORT\" \"${ray_start_common_opts[@]}\" --system-config='{\"gcs_server_request_timeout_seconds\": 60, \"gcs_rpc_server_reconnect_timeout_s\": 60}'\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$RAY_ADDRESS\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head node. Exiting.\" >&2; ray stop --force; exit 1; fi\n                echo \"Head node not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head node is healthy.\"\n        else\n            local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n            echo \"INFO: Worker node $(hostname) waiting for head at $head_node_address...\"\n            local start_time=$(date +%s)\n            while ! 
ray health-check --address \"$head_node_address\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head. Exiting.\" >&2; exit 1; fi\n                echo \"Head not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head is healthy. Worker starting...\"\n            ray start --address=\"$head_node_address\" \"${ray_start_common_opts[@]}\"\n        fi\n    else\n        echo \"INFO: Starting Ray in single-node mode...\"\n        ray start --head \"${ray_start_common_opts[@]}\"\n    fi\n}\n\n# --- Main Execution Function ---\nmain() {\n    local timestamp=$(date +\"%Y%m%d_%H%M%S\")\n    ray stop --force\n\n\n\n    export VLLM_USE_V1=1\n    export GLOO_SOCKET_TIMEOUT=600\n    export GLOO_TCP_TIMEOUT=600\n    export GLOO_LOG_LEVEL=DEBUG\n    export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}\n    export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}\n    export RAY_MASTER_ADDR=$MASTER_ADDR\n\n    start_ray_cluster\n\n    if [ \"$NNODES\" -gt 1 ] && [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"Waiting for all $NNODES nodes to join...\"\n        local TIMEOUT=600; local start_time=$(date +%s)\n        while true; do\n            if [ \"$(( $(date +%s) - start_time ))\" -ge \"$TIMEOUT\" ]; then echo \"Error: Timeout waiting for nodes.\" >&2; exit 1; fi\n            local ready_nodes=$(ray list nodes --format=json | python3 -c \"import sys, json; print(len(json.load(sys.stdin)))\")\n            if [ \"$ready_nodes\" -ge \"$NNODES\" ]; then break; fi\n            echo \"Waiting... ($ready_nodes / $NNODES nodes ready)\"\n            sleep 5\n        done\n        echo \"All $NNODES nodes have joined.\"\n    fi\n\n    if [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"INFO [RANK 0]: Starting main training command.\"\n        eval \"${TRAINING_CMD[@]}\" \"$@\"\n        echo \"INFO [RANK 0]: Training finished.\"\n        sleep 30; ray stop --force >/dev/null 2>&1\n    elif [ \"$NNODES\" -gt 1 ]; then\n        local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n        echo \"INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address.\"\n        while ray health-check --address \"$head_node_address\" &>/dev/null; do sleep 15; done\n        echo \"INFO [RANK $NODE_RANK]: Head node down. Exiting.\"\n    fi\n\n    echo \"INFO: Script finished on rank $NODE_RANK.\"\n}\n\n# --- Script Entrypoint ---\nmain \"$@\"\n"
  },
  {
    "path": "examples/cpgd_trainer/run_qwen3-8b.sh",
    "content": "#!/usr/bin/env bash\n# ===================================================================================\n# ===                       USER CONFIGURATION SECTION                            ===\n# ===================================================================================\n\n# --- Experiment and Model Definition ---\nexport DATASET=deepscaler\nexport ALG=cpgd\nexport MODEL_NAME=qwen3-8b\n\n# --- Path Definitions ---\nexport HOME={your_home_path}\nexport TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet\nexport TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet\nexport MODEL_PATH=$HOME/data/models/Qwen3-8B\n\n# Base output paths\nexport BASE_CKPT_PATH=ckpts\nexport BASE_TENSORBOARD_PATH=tensorboard\n\n# --- Key Training Hyperparameters ---\nexport TRAIN_BATCH_SIZE_PER_NODE=512\nexport PPO_MINI_BATCH_SIZE_PER_NODE=256\nexport PPO_MICRO_BATCH_SIZE_PER_GPU=8\nexport MAX_PROMPT_LENGTH=1024\nexport MAX_RESPONSE_LENGTH=2048\nexport ROLLOUT_GPU_MEMORY_UTILIZATION=0.5\nexport ROLLOUT_TP=2\nexport ROLLOUT_N=8\nexport SAVE_FREQ=30\nexport TEST_FREQ=10\nexport TOTAL_EPOCHS=30\nexport MAX_CKPT_KEEP=5\n\n# --- Multi-node (Multi-machine) distributed training environments ---\n\n# Uncomment the following line and set the correct network interface if needed for distributed backend\n# export GLOO_SOCKET_IFNAME=bond0  # Modify as needed\n\n# --- Distributed Training & Infrastructure ---\nexport N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}\nexport NNODES=${PET_NNODES:-1}\nexport NODE_RANK=${PET_NODE_RANK:-0}\nexport MASTER_ADDR=${MASTER_ADDR:-localhost}\n\n# --- Output Paths and Experiment Naming ---\nexport CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes\nexport PROJECT_NAME=siirl_${DATASET}_${ALG}\nexport EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment\nexport TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp\nexport SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp\n\n# --- Calculated Global Hyperparameters ---\nexport TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))\nexport PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))\n\n# --- Define the Training Command and its Arguments ---\nTRAINING_CMD=(\n    python3 -m siirl.main_dag\n    algorithm.adv_estimator=\\$ALG\n    data.train_files=\\$TRAIN_DATA_PATH\n    data.val_files=\\$TEST_DATA_PATH\n    data.train_batch_size=\\$TRAIN_BATCH_SIZE\n    data.max_prompt_length=\\$MAX_PROMPT_LENGTH\n    data.max_response_length=\\$MAX_RESPONSE_LENGTH\n    data.filter_overlong_prompts=True\n    data.truncation='error'\n    data.shuffle=False\n    actor_rollout_ref.model.path=\\$MODEL_PATH\n    actor_rollout_ref.actor.optim.lr=1e-6\n    actor_rollout_ref.model.use_remove_padding=True\n    actor_rollout_ref.model.use_fused_kernels=False\n    actor_rollout_ref.actor.policy_drift_coeff=0.001\n    actor_rollout_ref.actor.policy_loss.loss_mode=cpgd\n    actor_rollout_ref.actor.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.actor.use_kl_loss=True\n    actor_rollout_ref.actor.grad_clip=0.5\n    actor_rollout_ref.actor.clip_ratio=0.2\n    actor_rollout_ref.actor.kl_loss_coef=0.01\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl\n    actor_rollout_ref.model.enable_gradient_checkpointing=True\n    actor_rollout_ref.actor.fsdp_config.param_offload=False\n    
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.rollout.tensor_model_parallel_size=\\$ROLLOUT_TP\n    actor_rollout_ref.rollout.name=vllm\n    actor_rollout_ref.rollout.gpu_memory_utilization=\\$ROLLOUT_GPU_MEMORY_UTILIZATION\n    actor_rollout_ref.rollout.max_model_len=8192\n    actor_rollout_ref.rollout.enable_chunked_prefill=False\n    actor_rollout_ref.rollout.enforce_eager=False\n    actor_rollout_ref.rollout.free_cache_engine=False\n    actor_rollout_ref.rollout.n=\\$ROLLOUT_N\n    actor_rollout_ref.rollout.prompt_length=\\$MAX_PROMPT_LENGTH  \n    actor_rollout_ref.rollout.response_length=\\$MAX_RESPONSE_LENGTH\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.ref.fsdp_config.param_offload=True\n    algorithm.weight_factor_in_cpgd='STD_weight'\n    algorithm.kl_ctrl.kl_coef=0.001\n    trainer.critic_warmup=0\n    trainer.logger=['console','tensorboard']\n    trainer.project_name=\\$PROJECT_NAME\n    trainer.experiment_name=\\$EXPERIMENT_NAME\n    trainer.n_gpus_per_node=\\$N_GPUS_PER_NODE\n    trainer.nnodes=\\$NNODES\n    trainer.save_freq=\\$SAVE_FREQ\n    trainer.test_freq=\\$TEST_FREQ\n    trainer.total_epochs=\\$TOTAL_EPOCHS\n    trainer.resume_mode=auto\n    trainer.max_actor_ckpt_to_keep=\\$MAX_CKPT_KEEP\n    trainer.default_local_dir=\\$CKPT_PATH\n    trainer.val_before_train=True\n)\n\n# ===================================================================================\n# ===                  MAIN EXECUTION LOGIC & INFRASTRUCTURE                      ===\n# ===================================================================================\n\n# --- Boilerplate Setup ---\nset -e\nset -o pipefail\nset -x\n\n# --- Infrastructure & Boilerplate Functions ---\nstart_ray_cluster() {\n    local RAY_HEAD_WAIT_TIMEOUT=600\n    export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200\n    export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120\n\n    local ray_start_common_opts=(\n        --num-gpus \"$N_GPUS_PER_NODE\"\n        --object-store-memory 100000000000\n        --memory 100000000000\n    )\n\n    if [ \"$NNODES\" -gt 1 ]; then\n        if [ \"$NODE_RANK\" = \"0\" ]; then\n            echo \"INFO: Starting Ray head node on $(hostname)...\"\n            export RAY_ADDRESS=\"$RAY_MASTER_ADDR:$RAY_MASTER_PORT\"\n            ray start --head --port=\"$RAY_MASTER_PORT\" --dashboard-port=\"$RAY_DASHBOARD_PORT\" \"${ray_start_common_opts[@]}\" --system-config='{\"gcs_server_request_timeout_seconds\": 60, \"gcs_rpc_server_reconnect_timeout_s\": 60}'\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$RAY_ADDRESS\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head node. Exiting.\" >&2; ray stop --force; exit 1; fi\n                echo \"Head node not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head node is healthy.\"\n        else\n            local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n            echo \"INFO: Worker node $(hostname) waiting for head at $head_node_address...\"\n            local start_time=$(date +%s)\n            while ! 
ray health-check --address \"$head_node_address\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head. Exiting.\" >&2; exit 1; fi\n                echo \"Head not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head is healthy. Worker starting...\"\n            ray start --address=\"$head_node_address\" \"${ray_start_common_opts[@]}\"\n        fi\n    else\n        echo \"INFO: Starting Ray in single-node mode...\"\n        ray start --head \"${ray_start_common_opts[@]}\"\n    fi\n}\n\n# --- Main Execution Function ---\nmain() {\n    local timestamp=$(date +\"%Y%m%d_%H%M%S\")\n    ray stop --force\n\n\n\n    export VLLM_USE_V1=1\n    export GLOO_SOCKET_TIMEOUT=600\n    export GLOO_TCP_TIMEOUT=600\n    export GLOO_LOG_LEVEL=DEBUG\n    export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}\n    export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}\n    export RAY_MASTER_ADDR=$MASTER_ADDR\n\n    start_ray_cluster\n\n    if [ \"$NNODES\" -gt 1 ] && [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"Waiting for all $NNODES nodes to join...\"\n        local TIMEOUT=600; local start_time=$(date +%s)\n        while true; do\n            if [ \"$(( $(date +%s) - start_time ))\" -ge \"$TIMEOUT\" ]; then echo \"Error: Timeout waiting for nodes.\" >&2; exit 1; fi\n            local ready_nodes=$(ray list nodes --format=json | python3 -c \"import sys, json; print(len(json.load(sys.stdin)))\")\n            if [ \"$ready_nodes\" -ge \"$NNODES\" ]; then break; fi\n            echo \"Waiting... ($ready_nodes / $NNODES nodes ready)\"\n            sleep 5\n        done\n        echo \"All $NNODES nodes have joined.\"\n    fi\n\n    if [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"INFO [RANK 0]: Starting main training command.\"\n        eval \"${TRAINING_CMD[@]}\" \"$@\"\n        echo \"INFO [RANK 0]: Training finished.\"\n        sleep 30; ray stop --force >/dev/null 2>&1\n    elif [ \"$NNODES\" -gt 1 ]; then\n        local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n        echo \"INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address.\"\n        while ray health-check --address \"$head_node_address\" &>/dev/null; do sleep 15; done\n        echo \"INFO [RANK $NODE_RANK]: Head node down. Exiting.\"\n    fi\n\n    echo \"INFO: Script finished on rank $NODE_RANK.\"\n}\n\n# --- Script Entrypoint ---\nmain \"$@\"\n"
  },
  {
    "path": "examples/custom_pipeline_example/custom_grpo.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nCustom Pipeline Examples\n\nThis file demonstrates how users can define custom training pipelines\nusing the new Pipeline API. All functions are explicitly visible in the code.\n\"\"\"\n\nimport numpy as np\nfrom siirl.execution.dag.pipeline import Pipeline, NodeConfig\nfrom siirl.execution.dag.task_graph import TaskGraph\nfrom tensordict import TensorDict\nfrom siirl.dag_worker.data_structures import NodeOutput\n\n\n# ============================================================================\n# Example 1: Use Built-in Pipeline (Simplest)\n# ============================================================================\n\ndef example_builtin_grpo() -> TaskGraph:\n    \"\"\"\n    Simplest way: Use built-in GRPO pipeline directly.\n\n    This is recommended for users who want to use standard algorithms\n    without customization.\n    \"\"\"\n    from siirl.execution.dag.builtin_pipelines import grpo_pipeline\n    return grpo_pipeline()\n\n\n# ============================================================================\n# Example 2: GRPO with Custom Reward Function\n# ============================================================================\n\ndef grpo_with_custom_reward() -> TaskGraph:\n    \"\"\"\n    Customize the reward computation while keeping other parts standard.\n\n    This example shows how to replace the reward node with a custom function\n    while keeping the rest of the pipeline standard.\n    \"\"\"\n    pipeline = Pipeline(\n        \"grpo_custom_reward\",\n        \"GRPO pipeline with custom reward function\"\n    )\n\n    # Standard rollout\n    pipeline.add_node(\n        \"rollout_actor\",\n        func=\"siirl.dag_worker.dagworker:DAGWorker.generate\",\n        deps=[]\n    )\n\n    # Custom reward function (user's own implementation)\n    pipeline.add_node(\n        \"custom_reward\",\n        func=\"examples.custom_pipeline_example.custom_grpo:my_custom_reward_fn\",\n        deps=[\"rollout_actor\"]\n    )\n\n    # Standard advantage calculation and training\n    pipeline.add_node(\n        \"calculate_advantages\",\n        func=\"siirl.dag_worker.dagworker:DAGWorker.compute_advantage\",\n        deps=[\"custom_reward\"]\n    ).add_node(\n        \"actor_old_log_prob\",\n        func=\"siirl.dag_worker.dagworker:DAGWorker.compute_old_log_prob\",\n        deps=[\"calculate_advantages\"],\n        only_forward_compute=True\n    ).add_node(\n        \"reference_log_prob\",\n        func=\"siirl.dag_worker.dagworker:DAGWorker.compute_ref_log_prob\",\n        deps=[\"actor_old_log_prob\"]\n    ).add_node(\n        \"actor_train\",\n        func=\"siirl.dag_worker.dagworker:DAGWorker.train_actor\",\n        deps=[\"reference_log_prob\"]\n    )\n\n    return pipeline.build()\n\n\ndef my_custom_reward_fn(batch: TensorDict, **kwargs) -> NodeOutput:\n    \"\"\"\n    User's custom reward function.\n\n    This function can 
implement any custom reward logic.\n    Here we show a simple example, but users can implement\n    arbitrarily complex reward computations.\n\n    Args:\n        batch: TensorDict containing prompts and responses\n        **kwargs: Additional arguments (config, etc.)\n\n    Returns:\n        NodeOutput: Batch with computed rewards\n    \"\"\"\n    # Option 1: Use built-in reward computation as base\n    from siirl.execution.scheduler.reward import compute_reward\n    reward_output = compute_reward(batch, kwargs.get(\"config\"))\n\n    # Option 2: Fully custom reward logic\n    # responses = batch.non_tensor_batch.get(\"responses\", [])\n    # custom_rewards = np.array([score_response(r) for r in responses])\n    # batch.batch[\"rewards\"] = custom_rewards\n    # reward_output = NodeOutput(batch=batch, metrics={\"avg_reward\": custom_rewards.mean()})\n\n    return reward_output\n\n"
  },
  {
    "path": "examples/custom_reward/rewardfunc_gsm8k.py",
    "content": "import re\n\n\ndef extract_solution(solution_str, method=\"strict\"):\n    assert method in [\"strict\", \"flexible\"]\n\n    if method == \"strict\":\n        # this also tests the formatting of the model\n        solution = re.search(\"#### (\\\\-?[0-9\\\\.\\\\,]+)\", solution_str)\n        if solution is None:\n            final_answer = None\n        else:\n            final_answer = solution.group(0)\n            final_answer = final_answer.split(\"#### \")[1].replace(\",\", \"\").replace(\"$\", \"\")\n    elif method == \"flexible\":\n        answer = re.findall(\"(\\\\-?[0-9\\\\.\\\\,]+)\", solution_str)\n        final_answer = None\n        if len(answer) == 0:\n            # no reward is there is no answer\n            pass\n        else:\n            invalid_str = [\"\", \".\"]\n            # find the last number that is not '.'\n            for final_answer in reversed(answer):\n                if final_answer not in invalid_str:\n                    break\n    return final_answer\n\n\ndef compute_score(data_sources, solution_strs, ground_truths, extra_infos, method=\"strict\", format_score=0.0, score=1.0, **kwargs):\n    \"\"\"The scoring function for GSM8k.\n\n    Reference: Trung, Luong, et al. \"Reft: Reasoning with reinforced fine-tuning.\" Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 2024.\n\n    Args:\n        data_sources: a list of data sources\n        solution_strs: a list of solution texts\n        ground_truths: a list of ground truths\n        extra_infos: a list of extra infos\n        method: the method to extract the solution, choices are 'strict' and 'flexible'\n        format_score: the score for the format\n        score: the score for the correct answer\n    \"\"\"\n    scores = []\n    for solution_str, ground_truth in zip(solution_strs, ground_truths):\n        answer = extract_solution(solution_str=solution_str, method=method)\n        if answer is None:\n            scores.append(0)\n        else:\n            if answer == ground_truth:\n                scores.append(score)\n            else:\n                scores.append(format_score)\n    return scores"
  },
  {
    "path": "examples/custom_reward/run_qwen2_5-7b-custom_reward.sh",
    "content": "#!/usr/bin/env bash\n# Exit immediately if a command exits with a non-zero status.\nset -e\nset -o pipefail\n# Print commands and their arguments as they are executed for easy debugging.\nset -x\n\n# --- Environment Setup ---\n# bash /root/install_siirl.sh\n\n# Generate a timestamp for unique directory/file names.\ntimestamp=$(date +\"%Y%m%d_%H%M%S\")\n\n# Force stop any existing Ray cluster to ensure a clean start.\nray stop --force\n\n# --- Path and Environment Variable Definitions ---\n\n# Define environment variables for data, model, and checkpoint storage paths.\nexport DATASET=gsm8k\nexport ALG=grpo\nexport MODEL_NAME=qwen2.5-7b\nexport TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet\nexport TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet\nexport MODEL_PATH=$HOME/data/models/Qwen2.5-7B-Instruct\nexport CKPT_PATH=ckpts/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_$PET_NNODES\nexport PROJECT_NAME=siirl_${DATASET}_${ALG}\nexport EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment\n\n# Environment variables for Gloo (used for distributed communication).\n#export GLOO_SOCKET_IFNAME=bond0\nexport GLOO_SOCKET_TIMEOUT=600\nexport GLOO_TCP_TIMEOUT=600\nexport GLOO_LOG_LEVEL=DEBUG\n\n# Define paths for TensorBoard and logging outputs.\nexport TENSORBOARD_DIR=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${PET_NNODES}_$timestamp\nexport SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${PET_NNODES}_$timestamp\n\n# --- Training Hyperparameters ---\n\nexport TRAIN_BATCH_SIZE_PER_NODE=1024\nexport PPO_MINI_BATCH_SIZE_PER_NODE=256\nexport PPO_MICRO_BATCH_SIZE_PER_GPU=16\nexport MAX_PROMPT_LENGTH=2048\nexport MAX_RESPONSE_LENGTH=4096\nexport ROLLOUT_GPU_MEMORY_UTILIZATION=0.6\nexport ROLLOUT_TP=2\nexport ROLLOUT_N=8\nexport SAVE_FREQ=30\nexport TEST_FREQ=10\nexport TOTAL_EPOCHS=30\nexport MAX_CKPT_KEEP=5\n\n# --- Multi-node (Multi-machine) distributed training environments ---\n\n# Uncomment the following line and set the correct network interface if needed for distributed backend\n# export GLOO_SOCKET_IFNAME=bond0  # Modify as needed\n\n# --- Cluster Configuration (Usually no changes needed below) ---\n\n# These variables are typically set by the cluster job scheduler (e.g., Slurm, DLC).\nexport N_GPUS_PER_NODE=8\nexport NNODES=${PET_NNODES:-1}\nexport NODE_RANK=${PET_NODE_RANK:-0}\nexport MASTER_ADDR=${MASTER_ADDR:-localhost}\nexport VLLM_USE_V1=1\n\n# Calculate the global batch sizes based on the number of nodes.\nexport TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))\nexport PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))\n\n# Ray cluster connection settings.\nexport RAY_MASTER_PORT=6379\nexport RAY_DASHBOARD_PORT=8265\nexport RAY_MASTER_ADDR=$MASTER_ADDR\n\n# --- Ray Cluster Start Function (Robust for Large Scale) ---\n\nstart_ray_cluster() {\n    # Set a generous timeout for workers waiting for the head node.\n    local RAY_HEAD_WAIT_TIMEOUT=600 # 10 minutes\n\n    # For stability in large clusters, explicitly set Ray to use the same network interface.\n    export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=$INTERFACE_NAME\n    export RAY_GCS_SERVER_CONFIG_NIC_NAME=$INTERFACE_NAME\n    export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200\n    \n    # Increase Ray GCS client connection timeout for stability.\n    export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120\n    \n    local ray_start_common_opts=(\n        --num-gpus \"$N_GPUS_PER_NODE\"\n        --object-store-memory 100000000000\n        --memory 
100000000000\n    )\n\n    # Multi-node environment\n    if [ \"$NNODES\" -gt 1 ]; then\n        # Head node logic (rank 0)\n        if [ \"$NODE_RANK\" = \"0\" ]; then\n            echo \"INFO: Starting Ray head node on $(hostname)...\"\n            \n            # The head's address is its own resolved IP\n            export RAY_ADDRESS=\"$RAY_MASTER_ADDR:$RAY_MASTER_PORT\"\n            \n            ray start --head \\\n                --port=\"$RAY_MASTER_PORT\" \\\n                --dashboard-port=\"$RAY_DASHBOARD_PORT\" \\\n                \"${ray_start_common_opts[@]}\" \\\n                --system-config='{\"gcs_server_request_timeout_seconds\": 60, \"gcs_rpc_server_reconnect_timeout_s\": 60}'\n            \n            echo \"INFO: Ray head started. Waiting for services to become healthy at $RAY_ADDRESS...\"\n            \n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$RAY_ADDRESS\" &>/dev/null; do\n                local current_time=$(date +%s)\n                local elapsed_time=$((current_time - start_time))\n                if [ \"$elapsed_time\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then\n                    echo \"ERROR: Timed out after ${RAY_HEAD_WAIT_TIMEOUT}s waiting for the local head node services. Exiting.\" >&2\n                    ray stop --force\n                    exit 1\n                fi\n                echo \"Head node services not healthy yet. Retrying in 5 seconds...\"\n                sleep 5\n            done\n            echo \"INFO: Head node services are healthy.\"\n        \n        # Worker node logic (all other ranks)\n        else\n            # The address to connect to is the master node's address from the job scheduler\n            local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n            echo \"INFO: Worker node $(hostname) waiting for head node at $head_node_address...\"\n            \n            local start_time=$(date +%s)\n            # ROBUST CHECK: Use `ray health-check` to wait for the head.\n            while ! ray health-check --address \"$head_node_address\" &>/dev/null; do\n                local current_time=$(date +%s)\n                local elapsed_time=$((current_time - start_time))\n\n                if [ \"$elapsed_time\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then\n                    echo \"ERROR: Timed out after ${RAY_HEAD_WAIT_TIMEOUT}s waiting for the head node to be healthy. Exiting.\" >&2\n                    exit 1\n                fi\n                \n                echo \"Head node at $head_node_address not healthy yet. Retrying in 5 seconds...\"\n                sleep 5\n            done\n\n            echo \"INFO: Head node is healthy! 
Worker node $(hostname) is starting and connecting.\"\n            ray start --address=\"$head_node_address\" \\\n                \"${ray_start_common_opts[@]}\" \\\n                --block # Use --block to keep the script running until the worker is stopped.\n        fi\n    # Single-node setup\n    else\n        echo \"INFO: Starting Ray in single-node mode...\"\n        ray start --head \"${ray_start_common_opts[@]}\"\n    fi\n}\n\n\n# --- Training Launch Function ---\n\nstart_training() {\n    if [ \"$NODE_RANK\" = \"0\" ]; then\n        python3 -m siirl.main_dag \\\n            algorithm.adv_estimator=grpo \\\n            data.train_files=$TRAIN_DATA_PATH \\\n            data.val_files=$TEST_DATA_PATH \\\n            data.train_batch_size=$TRAIN_BATCH_SIZE \\\n            data.max_prompt_length=$MAX_PROMPT_LENGTH \\\n            data.max_response_length=$MAX_RESPONSE_LENGTH \\\n            data.filter_overlong_prompts=True \\\n            data.truncation='error' \\\n            data.shuffle=False \\\n            actor_rollout_ref.model.path=$MODEL_PATH \\\n            actor_rollout_ref.actor.optim.lr=1e-6 \\\n            actor_rollout_ref.actor.policy_drift_coeff=0.001 \\\n            actor_rollout_ref.actor.use_cpgd_loss=True \\\n            actor_rollout_ref.model.use_remove_padding=True \\\n            actor_rollout_ref.model.use_fused_kernels=False \\\n            actor_rollout_ref.actor.ppo_mini_batch_size=$PPO_MINI_BATCH_SIZE \\\n            actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=$PPO_MICRO_BATCH_SIZE_PER_GPU \\\n            actor_rollout_ref.actor.use_kl_loss=True \\\n            actor_rollout_ref.actor.grad_clip=0.5 \\\n            actor_rollout_ref.actor.clip_ratio=0.2 \\\n            actor_rollout_ref.actor.kl_loss_coef=0.01 \\\n            actor_rollout_ref.actor.kl_loss_type=low_var_kl \\\n            actor_rollout_ref.model.enable_gradient_checkpointing=True \\\n            actor_rollout_ref.actor.fsdp_config.param_offload=False \\\n            actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \\\n            actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=$PPO_MICRO_BATCH_SIZE_PER_GPU \\\n            actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP \\\n            actor_rollout_ref.rollout.name=vllm \\\n            actor_rollout_ref.rollout.gpu_memory_utilization=$ROLLOUT_GPU_MEMORY_UTILIZATION \\\n            actor_rollout_ref.rollout.max_model_len=8192 \\\n            actor_rollout_ref.rollout.enable_chunked_prefill=False \\\n            actor_rollout_ref.rollout.enforce_eager=False \\\n            actor_rollout_ref.rollout.free_cache_engine=False \\\n            actor_rollout_ref.rollout.n=$ROLLOUT_N \\\n            actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=$PPO_MICRO_BATCH_SIZE_PER_GPU \\\n            actor_rollout_ref.ref.fsdp_config.param_offload=True \\\n            algorithm.kl_ctrl.kl_coef=0.001 \\\n            trainer.critic_warmup=0 \\\n            trainer.logger=['console','tensorboard']  \\\n            trainer.project_name=$PROJECT_NAME \\\n            trainer.experiment_name=$EXPERIMENT_NAME \\\n            trainer.n_gpus_per_node=$N_GPUS_PER_NODE \\\n            trainer.nnodes=$NNODES \\\n            trainer.save_freq=$SAVE_FREQ \\\n            trainer.test_freq=$TEST_FREQ \\\n            trainer.total_epochs=$TOTAL_EPOCHS \\\n            trainer.resume_mode=auto \\\n            trainer.max_actor_ckpt_to_keep=$MAX_CKPT_KEEP \\\n            trainer.default_local_dir=$CKPT_PATH \\\n           
 trainer.val_before_train=True \\\n            custom_reward_function.path=$HOME/rl/rewardfunc_gsm8k.py \\\n            custom_reward_function.name=compute_score \\\n            reward_model.reward_manager=batch $@\n    fi\n}\n\n# --- Main Execution Logic ---\n\n# Start the Ray cluster (handles both single and multi-node cases).\nstart_ray_cluster\n\n# This logic should only run on the head node (NODE_RANK=0) in a multi-node setup.\nif [ \"$NNODES\" -gt 1 ] && [ \"$NODE_RANK\" = \"0\" ]; then\n    echo \"Head node is up. Waiting for all $NNODES nodes to join the cluster...\"\n    TIMEOUT_SECONDS=600\n\n    # This command gets the list of nodes in JSON format and parses it with Python to count them.\n    # 'ray list nodes' is the correct and modern way to get this information from the CLI.\n    get_ready_nodes_cmd='ray list nodes --limit=5000 --format=json | python3 -c \"import sys, json; print(len(json.load(sys.stdin)))\"'\n    start_time=$(date +%s)\n    \n    # Loop until the number of ready nodes equals the expected number of nodes.\n    while true; do\n        # --- Timeout Check ---\n        current_time=$(date +%s)\n        elapsed_time=$((current_time - start_time))\n        if [ \"$elapsed_time\" -ge \"$TIMEOUT_SECONDS\" ]; then\n            echo \"Error: Timeout! Waited for ${TIMEOUT_SECONDS} seconds, but not all nodes joined.\" >&2\n            exit 1 # Exit with an error code\n        fi\n\n        # Execute the command to get the current count of ready nodes.\n        # '2>/dev/null' suppresses errors if the ray client isn't ready yet, preventing script failure.\n        ready_nodes=$(eval \"$get_ready_nodes_cmd\" 2>/dev/null) || ready_nodes=0\n\n        if [ \"$ready_nodes\" -ge \"$NNODES\" ]; then\n            break # All nodes have joined, exit the loop.\n        fi\n\n        echo \"Waiting for all worker nodes to register... ($ready_nodes / $NNODES nodes ready)\"\n        sleep 2\n    done\n\n    echo \"All $NNODES nodes have successfully joined the cluster.\"\nfi\n\n# --- Script Continuation ---\necho \"Node initialization complete. Continuing with main task...\"\n\n\nstart_training"
  },
  {
    "path": "examples/dapo_trainer/run_qwen2_5-7b.sh",
    "content": "#!/usr/bin/env bash\n# ===================================================================================\n# ===                       USER CONFIGURATION SECTION                            ===\n# ===                    DAPO                                                                             ===\n# ===================================================================================\n\n# --- Experiment and Model Definition ---\nexport DATASET=dapo-math-17k\nexport ALG=grpo  # DAPO uses GRPO (Group Relative Policy Optimization) as the base algorithm\nexport MODEL_NAME=qwen2.5-7b\n\n# --- Path Definitions ---\n# export HOME={your_home_path}\n# export TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet\n# export TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet\n# export MODEL_PATH=$HOME/data/models/Qwen2.5-VL-7B-Instruct\nexport TRAIN_DATA_PATH=/inspire/hdd/project/qianghuaxuexi/public/datasets/DAPO-Math-17k/dapo-math-17k.parquet\nexport TEST_DATA_PATH=/inspire/hdd/project/qianghuaxuexi/public/datasets/gsm8k/test.parquet\nexport MODEL_PATH=/inspire/hdd/project/qianghuaxuexi/public/models/Qwen3-1.7B\n\n# Base output paths\nexport BASE_CKPT_PATH=ckpts\nexport BASE_TENSORBOARD_PATH=tensorboard\n\n# --- Key Training Hyperparameters ---\nexport TRAIN_BATCH_SIZE_PER_NODE=512\nexport PPO_MINI_BATCH_SIZE_PER_NODE=256\nexport INFER_MICRO_BATCH_SIZE=8\nexport TRAIN_MICRO_BATCH_SIZE=8\nexport OFFLOAD=False\nexport MAX_PROMPT_LENGTH=2048\nexport MAX_RESPONSE_LENGTH=4096\nexport ROLLOUT_GPU_MEMORY_UTILIZATION=0.7\nexport ROLLOUT_TP=2\nexport ROLLOUT_N=8\nexport SAVE_FREQ=30\nexport TEST_FREQ=10\nexport TOTAL_EPOCHS=30\nexport MAX_CKPT_KEEP=5\n\n# --- Multi-node (Multi-machine) distributed training environments ---\n\n# Uncomment the following line and set the correct network interface if needed for distributed backend\n# export GLOO_SOCKET_IFNAME=bond0  # Modify as needed\n\n# --- Distributed Training & Infrastructure ---\nexport N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}\nexport NNODES=${PET_NNODES:-1}\nexport NODE_RANK=${PET_NODE_RANK:-0}\nexport MASTER_ADDR=${MASTER_ADDR:-localhost}\n\n# --- Output Paths and Experiment Naming ---\nexport CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes\nexport PROJECT_NAME=dapo_${DATASET}_${ALG}\nexport EXPERIMENT_NAME=dapo_${MODEL_NAME}_${ALG}_${DATASET}_${NNODES}_experiment\nexport TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp\nexport SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp\n\n# --- Calculated Global Hyperparameters ---\nexport TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))\nexport PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))\nexport MAX_NUM_TOKEN_PER_GPU=$(($MAX_PROMPT_LENGTH + $MAX_RESPONSE_LENGTH))\n\n# --- DAPO-specific Hyperparameters ---\n# Filter groups: Enable dynamic sampling based on trajectory variance\nexport ENABLE_FILTER_GROUPS=True\nexport FILTER_GROUPS_METRIC=acc  # Metric used for filtering (accuracy)\nexport MAX_NUM_GEN_BATCHES=10    # Maximum generation batches before giving up\nexport GEN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * 3))  # Generation batch size (3x training batch per node)\n\n# KL divergence control\nexport USE_KL_IN_REWARD=False    # Whether to use KL penalty in reward\nexport KL_COEF=0.0               # KL coefficient for reward penalty\nexport USE_KL_LOSS=False         # Whether to use KL loss in actor training\nexport 
KL_LOSS_COEF=0.0          # KL loss coefficient\n\n# PPO clipping parameters for DAPO\nexport CLIP_RATIO_LOW=0.2        # Lower bound for PPO clipping\nexport CLIP_RATIO_HIGH=0.28      # Upper bound for PPO clipping\nexport LOSS_AGG_MODE=\"token-mean\" # Loss aggregation mode\n\n# Overlong sequence handling\nexport ENABLE_OVERLONG_BUFFER=True\nexport OVERLONG_BUFFER_LEN=512\nexport OVERLONG_PENALTY_FACTOR=1.0\n\n# Sampling parameters\nexport TEMPERATURE=1.0           # Sampling temperature\nexport TOP_P=1.0                 # Top-p sampling\nexport TOP_K=-1                  # Top-k sampling (-1 for disabled)\n\n# --- Define the Training Command and its Arguments ---\nTRAINING_CMD=(\n    python3 -m siirl.main_dag\n    algorithm.adv_estimator=\\$ALG\n    data.train_files=\\$TRAIN_DATA_PATH\n    data.val_files=\\$TEST_DATA_PATH\n    data.train_batch_size=\\$TRAIN_BATCH_SIZE\n    data.gen_batch_size=\\$GEN_BATCH_SIZE\n    data.max_prompt_length=\\$MAX_PROMPT_LENGTH\n    data.max_response_length=\\$MAX_RESPONSE_LENGTH\n    data.filter_overlong_prompts=True\n    data.truncation='left'\n    data.shuffle=False\n    data.prompt_key=prompt\n    actor_rollout_ref.model.path=\\$MODEL_PATH\n    actor_rollout_ref.actor.optim.lr=1e-6\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10\n    actor_rollout_ref.actor.optim.weight_decay=0.1\n    actor_rollout_ref.model.use_remove_padding=True\n    actor_rollout_ref.model.use_fused_kernels=False\n    actor_rollout_ref.actor.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\\$TRAIN_MICRO_BATCH_SIZE\n    actor_rollout_ref.actor.use_kl_loss=\\$USE_KL_LOSS\n    actor_rollout_ref.actor.kl_loss_coef=\\$KL_LOSS_COEF\n    actor_rollout_ref.actor.grad_clip=1.0\n    actor_rollout_ref.actor.clip_ratio_low=\\$CLIP_RATIO_LOW\n    actor_rollout_ref.actor.clip_ratio_high=\\$CLIP_RATIO_HIGH\n    actor_rollout_ref.actor.clip_ratio_c=10.0\n    actor_rollout_ref.actor.entropy_coeff=0\n    actor_rollout_ref.actor.loss_agg_mode=\\$LOSS_AGG_MODE\n    actor_rollout_ref.model.enable_gradient_checkpointing=True\n    actor_rollout_ref.actor.fsdp_config.param_offload=\\$OFFLOAD\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=\\$OFFLOAD\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\\$INFER_MICRO_BATCH_SIZE\n    actor_rollout_ref.rollout.tensor_model_parallel_size=\\$ROLLOUT_TP\n    actor_rollout_ref.rollout.name=vllm\n    actor_rollout_ref.rollout.gpu_memory_utilization=\\$ROLLOUT_GPU_MEMORY_UTILIZATION\n    actor_rollout_ref.rollout.enable_chunked_prefill=True\n    actor_rollout_ref.rollout.enforce_eager=False\n    actor_rollout_ref.rollout.free_cache_engine=False\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH))\n    actor_rollout_ref.rollout.max_model_len=$((MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH))\n    actor_rollout_ref.rollout.n=\\$ROLLOUT_N\n    actor_rollout_ref.rollout.prompt_length=\\$MAX_PROMPT_LENGTH  \n    actor_rollout_ref.rollout.response_length=\\$MAX_RESPONSE_LENGTH\n    actor_rollout_ref.rollout.temperature=\\$TEMPERATURE\n    actor_rollout_ref.rollout.top_p=\\$TOP_P\n    actor_rollout_ref.rollout.top_k=\\$TOP_K\n    actor_rollout_ref.rollout.val_kwargs.temperature=\\$TEMPERATURE\n    actor_rollout_ref.rollout.val_kwargs.top_p=\\$TOP_P\n    actor_rollout_ref.rollout.val_kwargs.top_k=\\$TOP_K\n    
actor_rollout_ref.rollout.val_kwargs.do_sample=True\n    actor_rollout_ref.rollout.val_kwargs.n=1\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\\$INFER_MICRO_BATCH_SIZE\n    actor_rollout_ref.ref.fsdp_config.param_offload=\\$OFFLOAD\n    algorithm.workflow_type=dapo\n    algorithm.use_kl_in_reward=\\$USE_KL_IN_REWARD\n    algorithm.kl_ctrl.kl_coef=\\$KL_COEF\n    algorithm.filter_groups.enable=\\$ENABLE_FILTER_GROUPS\n    algorithm.filter_groups.metric=\\$FILTER_GROUPS_METRIC\n    algorithm.filter_groups.max_num_gen_batches=\\$MAX_NUM_GEN_BATCHES\n    reward_model.reward_manager=dapo\n    reward_model.overlong_buffer.enable=\\$ENABLE_OVERLONG_BUFFER\n    reward_model.overlong_buffer.len=\\$OVERLONG_BUFFER_LEN\n    reward_model.overlong_buffer.penalty_factor=\\$OVERLONG_PENALTY_FACTOR\n    trainer.critic_warmup=0\n    trainer.logger=[\"console\",\"tensorboard\"]\n    trainer.project_name=\\$PROJECT_NAME\n    trainer.experiment_name=\\$EXPERIMENT_NAME\n    trainer.n_gpus_per_node=\\$N_GPUS_PER_NODE\n    trainer.nnodes=\\$NNODES\n    trainer.save_freq=\\$SAVE_FREQ\n    trainer.test_freq=\\$TEST_FREQ\n    trainer.total_epochs=\\$TOTAL_EPOCHS\n    trainer.resume_mode=auto\n    trainer.max_actor_ckpt_to_keep=\\$MAX_CKPT_KEEP\n    trainer.default_local_dir=\\$CKPT_PATH\n    trainer.val_before_train=True\n)\n\n# ===================================================================================\n# ===                  MAIN EXECUTION LOGIC & INFRASTRUCTURE                      ===\n# ===================================================================================\n\n# --- Boilerplate Setup ---\nset -e\nset -o pipefail\nset -x\n\n# --- Infrastructure & Boilerplate Functions ---\nstart_ray_cluster() {\n    local RAY_HEAD_WAIT_TIMEOUT=600\n    export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200\n    export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120\n\n    local ray_start_common_opts=(\n        --num-gpus \"$N_GPUS_PER_NODE\"\n        --object-store-memory 100000000000\n        --memory 100000000000\n    )\n\n    if [ \"$NNODES\" -gt 1 ]; then\n        if [ \"$NODE_RANK\" = \"0\" ]; then\n            echo \"INFO: Starting Ray head node on $(hostname)...\"\n            export RAY_ADDRESS=\"$RAY_MASTER_ADDR:$RAY_MASTER_PORT\"\n            ray start --head --port=\"$RAY_MASTER_PORT\" --dashboard-port=\"$RAY_DASHBOARD_PORT\" \"${ray_start_common_opts[@]}\" --system-config='{\"gcs_server_request_timeout_seconds\": 60, \"gcs_rpc_server_reconnect_timeout_s\": 60}'\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$RAY_ADDRESS\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head node. Exiting.\" >&2; ray stop --force; exit 1; fi\n                echo \"Head node not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head node is healthy.\"\n        else\n            local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n            echo \"INFO: Worker node $(hostname) waiting for head at $head_node_address...\"\n            local start_time=$(date +%s)\n            while ! 
ray health-check --address \"$head_node_address\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head. Exiting.\" >&2; exit 1; fi\n                echo \"Head not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head is healthy. Worker starting...\"\n            ray start --address=\"$head_node_address\" \"${ray_start_common_opts[@]}\"\n        fi\n    else\n        echo \"INFO: Starting Ray in single-node mode...\"\n        ray start --head \"${ray_start_common_opts[@]}\"\n    fi\n}\n\n# --- Main Execution Function ---\nmain() {\n    local timestamp=$(date +\"%Y%m%d_%H%M%S\")\n    ray stop --force\n\n\n\n    export VLLM_USE_V1=1\n    export GLOO_SOCKET_TIMEOUT=600\n    export GLOO_TCP_TIMEOUT=600\n    export GLOO_LOG_LEVEL=DEBUG\n    export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}\n    export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}\n    export RAY_MASTER_ADDR=$MASTER_ADDR\n\n    start_ray_cluster\n\n    if [ \"$NNODES\" -gt 1 ] && [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"Waiting for all $NNODES nodes to join...\"\n        local TIMEOUT=600; local start_time=$(date +%s)\n        while true; do\n            if [ \"$(( $(date +%s) - start_time ))\" -ge \"$TIMEOUT\" ]; then echo \"Error: Timeout waiting for nodes.\" >&2; exit 1; fi\n            local ready_nodes=$(ray list nodes --format=json | python3 -c \"import sys, json; print(len(json.load(sys.stdin)))\")\n            if [ \"$ready_nodes\" -ge \"$NNODES\" ]; then break; fi\n            echo \"Waiting... ($ready_nodes / $NNODES nodes ready)\"\n            sleep 5\n        done\n        echo \"All $NNODES nodes have joined.\"\n    fi\n\n    if [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"INFO [RANK 0]: Starting main training command.\"\n        eval \"${TRAINING_CMD[@]}\" \"$@\"\n        echo \"INFO [RANK 0]: Training finished.\"\n        sleep 30; ray stop --force >/dev/null 2>&1\n    elif [ \"$NNODES\" -gt 1 ]; then\n        local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n        echo \"INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address.\"\n        while ray health-check --address \"$head_node_address\" &>/dev/null; do sleep 15; done\n        echo \"INFO [RANK $NODE_RANK]: Head node down. Exiting.\"\n    fi\n\n    echo \"INFO: Script finished on rank $NODE_RANK.\"\n}\n\n# --- Script Entrypoint ---\nmain \"$@\"\n"
  },
  {
    "path": "examples/dapo_trainer/run_qwen3-235b-megatron-gspo.sh",
    "content": "#!/usr/bin/env bash\n# ===================================================================================\n# ===                       USER CONFIGURATION SECTION                            ===\n# ===================================================================================\n\n# --- For config debugging\nexport HYDRA_FULL_ERROR=0\nexport SIIRL_LOG_VERBOSITY=INFO\nexport RAY_DEDUP_LOGS=1\n\n# --- Experiment and Model Definition ---\nexport DATASET=DAPO-Math-17k\nexport MODEL_NAME=qwen3-235b-a22b\n\n# --- Path Definitions ---\nexport HOME={your_home_path}\nexport TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/dapo-math-17k.parquet\nexport TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet\nexport MODEL_PATH=$HOME/data/models/Qwen3-235B-A22B\n\n# Base output paths\nexport BASE_CKPT_PATH=ckpts\nexport BASE_TENSORBOARD_PATH=tensorboard\n\nexport CUDA_DEVICE_MAX_CONNECTIONS=1\n\n# --- Key Training Hyperparameters ---\nexport TRAIN_BATCH_SIZE_PER_NODE=32  # Conservative for 235B\nexport PPO_MINI_BATCH_SIZE_PER_NODE=32\nexport PPO_MICRO_BATCH_SIZE_PER_GPU=4\nexport MAX_PROMPT_LENGTH=$((1024 * 2))\nexport MAX_RESPONSE_LENGTH=$((1024 * 8))\nexport MAX_MODEL_LENGTH=$((1024 * 10))\n\nexport ROLLOUT_GPU_MEMORY_UTILIZATION=0.4  # Conservative for 235B\n\nexport ROLLOUT_TP=16  # High TP for 235B\nexport ROLLOUT_N=16\nexport SAVE_FREQ=30\nexport TEST_FREQ=10\nexport TOTAL_EPOCHS=15\nexport MAX_CKPT_KEEP=5\n\n# --- GSPO Specific Parameters ---\nexport LOSS_MODE=gspo\nexport ADV_ESTIMATOR=grpo\nexport CLIP_RATIO_LOW=3e-4\nexport CLIP_RATIO_HIGH=4e-4\nexport CLIP_RATIO_C=10.0\nexport LOSS_AGG_MODE=\"token-mean\"\n\n# --- DAPO-specific Hyperparameters ---\n# Filter groups: Enable dynamic sampling based on trajectory variance\nexport ENABLE_FILTER_GROUPS=True\nexport FILTER_GROUPS_METRIC=acc  # Metric used for filtering (accuracy)\nexport MAX_NUM_GEN_BATCHES=10    # Maximum generation batches before giving up\nexport GEN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * 3))  # Generation batch size (3x training batch per node)\n\n# Overlong sequence handling\nexport ENABLE_OVERLONG_BUFFER=True\nexport OVERLONG_BUFFER_LEN=$((1024 * 4))\nexport OVERLONG_PENALTY_FACTOR=1.0\n\n# Sampling parameters\nexport TEMPERATURE=1.0           # Sampling temperature\nexport TOP_P=1.0                 # Top-p sampling\nexport TOP_K=-1                  # Top-k sampling (-1 for disabled)\n\n# --- KL Configuration ---\nexport USE_KL_IN_REWARD=False\nexport KL_COEF=0.0\nexport USE_KL_LOSS=False\nexport KL_LOSS_COEF=0.0\nexport KL_LOSS_TYPE=low_var_kl\n\n# --- Megatron Parallelism for 235B ---\nexport ACTOR_REF_PP=8  # High pipeline parallel for 235B\nexport ACTOR_REF_TP=1  # Low tensor parallel\nexport ACTOR_REF_EP=8  # High expert parallel for MoE\nexport ACTOR_REF_CP=1  # Context parallel\nexport ACTOR_REF_SP=True  # Sequence parallel\n\nexport use_dynamic_bsz=False\n\n# --- Multi-node (Multi-machine) distributed training environments ---\n# Uncomment the following line and set the correct network interface if needed\n# export GLOO_SOCKET_IFNAME=bond0\n\n# --- Distributed Training & Infrastructure ---\nexport N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}\nexport NNODES=${PET_NNODES:-1}\nexport NODE_RANK=${PET_NODE_RANK:-0}\nexport MASTER_ADDR=${MASTER_ADDR:-localhost}\n\n# --- Output Paths and Experiment Naming ---\nexport CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${DATASET}_hybrid_${NNODES}nodes\nexport PROJECT_NAME=siirl_${DATASET}_${MODEL_NAME}\nexport 
EXPERIMENT_NAME=siirl_moe_megatron_${MODEL_NAME}_${DATASET}_experiment\nexport TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp\nexport SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${DATASET}_hybrid_${NNODES}_$timestamp\n\n# --- Calculated Global Hyperparameters ---\nexport TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))\nexport PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))\n\n# --- Define the Training Command and its Arguments ---\nTRAINING_CMD=(\n    python3 -m siirl.main_dag\n    algorithm.workflow_type=dapo\n    algorithm.adv_estimator=\\$ADV_ESTIMATOR\n    data.train_files=\\$TRAIN_DATA_PATH\n    data.val_files=\\$TEST_DATA_PATH\n    data.train_batch_size=\\$TRAIN_BATCH_SIZE\n    data.gen_batch_size=\\$GEN_BATCH_SIZE\n    data.max_prompt_length=\\$MAX_PROMPT_LENGTH\n    data.max_response_length=\\$MAX_RESPONSE_LENGTH\n    data.filter_overlong_prompts=True\n    data.truncation='error'\n    data.shuffle=False\n    actor_rollout_ref.model.path=\\$MODEL_PATH\n    actor_rollout_ref.actor.optim.lr=1e-6\n    actor_rollout_ref.model.use_remove_padding=True\n    actor_rollout_ref.model.use_fused_kernels=True\n    actor_rollout_ref.model.trust_remote_code=True\n    actor_rollout_ref.model.enable_gradient_checkpointing=True\n    actor_rollout_ref.actor.strategy=megatron\n    actor_rollout_ref.actor.use_dynamic_bsz=\\$use_dynamic_bsz\n    # GSPO specific loss configuration\n    actor_rollout_ref.actor.policy_loss.loss_mode=\\$LOSS_MODE\n    actor_rollout_ref.actor.loss_agg_mode=\\$LOSS_AGG_MODE\n    actor_rollout_ref.actor.clip_ratio_low=\\$CLIP_RATIO_LOW\n    actor_rollout_ref.actor.clip_ratio_high=\\$CLIP_RATIO_HIGH\n    actor_rollout_ref.actor.clip_ratio_c=\\$CLIP_RATIO_C\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=\\$use_dynamic_bsz\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=\\$use_dynamic_bsz\n    # Megatron configuration for actor (235B)\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=\\$ACTOR_REF_TP\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=\\$ACTOR_REF_PP\n    actor_rollout_ref.actor.megatron.expert_model_parallel_size=\\$ACTOR_REF_EP\n    actor_rollout_ref.actor.megatron.context_parallel_size=\\$ACTOR_REF_CP\n    actor_rollout_ref.actor.megatron.sequence_parallel=\\$ACTOR_REF_SP\n    actor_rollout_ref.actor.megatron.use_distributed_optimizer=True\n    actor_rollout_ref.actor.megatron.param_dtype=bfloat16\n    actor_rollout_ref.actor.megatron.param_offload=True\n    actor_rollout_ref.actor.megatron.optimizer_offload=True\n    actor_rollout_ref.actor.megatron.use_dist_checkpointing=False\n    actor_rollout_ref.actor.megatron.use_mbridge=True\n    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32\n    +actor_rollout_ref.actor.megatron.override_transformer_config.account_for_embedding_in_pipeline_split=True\n    +actor_rollout_ref.actor.megatron.override_transformer_config.account_for_loss_in_pipeline_split=True\n    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full\n    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1\n    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform\n    # PPO configuration\n    actor_rollout_ref.actor.policy_drift_coeff=0.001\n    actor_rollout_ref.actor.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n    
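# Note: PPO_MINI_BATCH_SIZE above is the global batch consumed per PPO update (PPO_MINI_BATCH_SIZE_PER_NODE x NNODES); the per-GPU micro batch below only bounds how many sequences go through a single forward/backward pass, with gradient accumulation covering the rest of the mini batch.\n    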
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.actor.use_kl_loss=\\$USE_KL_LOSS\n    actor_rollout_ref.actor.grad_clip=0.5\n    actor_rollout_ref.actor.clip_ratio=0.2\n    actor_rollout_ref.actor.kl_loss_coef=\\$KL_LOSS_COEF\n    actor_rollout_ref.actor.kl_loss_type=\\$KL_LOSS_TYPE\n    # Rollout configuration (235B)\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.rollout.tensor_model_parallel_size=\\$ROLLOUT_TP\n    actor_rollout_ref.rollout.name=vllm\n    actor_rollout_ref.rollout.gpu_memory_utilization=\\$ROLLOUT_GPU_MEMORY_UTILIZATION\n    actor_rollout_ref.rollout.max_model_len=\\$MAX_MODEL_LENGTH\n    actor_rollout_ref.rollout.enable_chunked_prefill=False\n    actor_rollout_ref.rollout.enforce_eager=True\n    actor_rollout_ref.rollout.free_cache_engine=True\n    actor_rollout_ref.rollout.n=\\$ROLLOUT_N\n    actor_rollout_ref.rollout.prompt_length=\\$MAX_PROMPT_LENGTH  \n    actor_rollout_ref.rollout.response_length=\\$MAX_RESPONSE_LENGTH\n    actor_rollout_ref.rollout.temperature=\\$TEMPERATURE\n    actor_rollout_ref.rollout.top_p=\\$TOP_P\n    actor_rollout_ref.rollout.top_k=\\$TOP_K\n    actor_rollout_ref.rollout.val_kwargs.temperature=\\$TEMPERATURE\n    actor_rollout_ref.rollout.val_kwargs.top_p=\\$TOP_P\n    actor_rollout_ref.rollout.val_kwargs.top_k=\\$TOP_K\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True\n    actor_rollout_ref.rollout.val_kwargs.n=1\n    # Reference model configuration (235B)\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=\\$ACTOR_REF_TP\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=\\$ACTOR_REF_PP\n    actor_rollout_ref.ref.megatron.expert_model_parallel_size=\\$ACTOR_REF_EP\n    actor_rollout_ref.ref.megatron.context_parallel_size=\\$ACTOR_REF_CP\n    actor_rollout_ref.ref.megatron.sequence_parallel=\\$ACTOR_REF_SP\n    actor_rollout_ref.ref.megatron.param_offload=True\n    actor_rollout_ref.ref.megatron.use_dist_checkpointing=False\n    # Algorithm configuration\n    algorithm.weight_factor_in_cpgd='STD_weight'\n    algorithm.use_kl_in_reward=\\$USE_KL_IN_REWARD\n    algorithm.kl_ctrl.kl_coef=\\$KL_COEF\n    algorithm.filter_groups.enable=\\$ENABLE_FILTER_GROUPS\n    algorithm.filter_groups.metric=\\$FILTER_GROUPS_METRIC\n    algorithm.filter_groups.max_num_gen_batches=\\$MAX_NUM_GEN_BATCHES\n    reward_model.reward_manager=dapo\n    reward_model.overlong_buffer.enable=\\$ENABLE_OVERLONG_BUFFER\n    reward_model.overlong_buffer.len=\\$OVERLONG_BUFFER_LEN\n    reward_model.overlong_buffer.penalty_factor=\\$OVERLONG_PENALTY_FACTOR\n    # Trainer configuration\n    trainer.critic_warmup=0\n    trainer.logger='[\"console\",\"tensorboard\"]'\n    trainer.project_name=\\$PROJECT_NAME\n    trainer.experiment_name=\\$EXPERIMENT_NAME\n    trainer.n_gpus_per_node=\\$N_GPUS_PER_NODE\n    trainer.nnodes=\\$NNODES\n    trainer.save_freq=\\$SAVE_FREQ\n    trainer.test_freq=\\$TEST_FREQ\n    trainer.total_epochs=\\$TOTAL_EPOCHS\n    trainer.resume_mode=off\n    trainer.max_actor_ckpt_to_keep=\\$MAX_CKPT_KEEP\n    trainer.default_local_dir=\\$CKPT_PATH\n    trainer.val_before_train=True\n    dag.enable_perf=True\n)\n\n# ===================================================================================\n# ===                  MAIN EXECUTION LOGIC & INFRASTRUCTURE                      ===\n# 
===================================================================================\n\n# --- Boilerplate Setup ---\nset -e\nset -o pipefail\nset -x\n\n# --- Infrastructure & Boilerplate Functions ---\nstart_ray_cluster() {\n    local RAY_HEAD_WAIT_TIMEOUT=600\n    export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200\n    export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120\n\n    local ray_start_common_opts=(\n        --num-gpus \"$N_GPUS_PER_NODE\"\n        --object-store-memory 100000000000\n        --memory 100000000000\n    )\n\n    if [ \"$NNODES\" -gt 1 ]; then\n        if [ \"$NODE_RANK\" = \"0\" ]; then\n            echo \"INFO: Starting Ray head node on $(hostname)...\"\n            export RAY_ADDRESS=\"$RAY_MASTER_ADDR:$RAY_MASTER_PORT\"\n            ray start --head --port=\"$RAY_MASTER_PORT\" --dashboard-port=\"$RAY_DASHBOARD_PORT\" \"${ray_start_common_opts[@]}\" --system-config='{\"gcs_server_request_timeout_seconds\": 60, \"gcs_rpc_server_reconnect_timeout_s\": 60}'\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$RAY_ADDRESS\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head node. Exiting.\" >&2; ray stop --force; exit 1; fi\n                echo \"Head node not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head node is healthy.\"\n        else\n            local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n            echo \"INFO: Worker node $(hostname) waiting for head at $head_node_address...\"\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$head_node_address\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head. Exiting.\" >&2; exit 1; fi\n                echo \"Head not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head is healthy. Worker starting...\"\n            ray start --address=\"$head_node_address\" \"${ray_start_common_opts[@]}\"\n        fi\n    else\n        echo \"INFO: Starting Ray in single-node mode...\"\n        ray start --head \"${ray_start_common_opts[@]}\"\n    fi\n}\n\n# --- Main Execution Function ---\nmain() {\n    local timestamp=$(date +\"%Y%m%d_%H%M%S\")\n    ray stop --force\n\n    export VLLM_USE_V1=1\n    export GLOO_SOCKET_TIMEOUT=600\n    export GLOO_TCP_TIMEOUT=600\n    export GLOO_LOG_LEVEL=DEBUG\n    export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}\n    export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}\n    export RAY_MASTER_ADDR=$MASTER_ADDR\n    export NCCL_TIMEOUT=7200\n    export GLOO_TIMEOUT_SECONDS=7200\n\n    start_ray_cluster\n\n    if [ \"$NNODES\" -gt 1 ] && [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"Waiting for all $NNODES nodes to join...\"\n        local TIMEOUT=600; local start_time=$(date +%s)\n        while true; do\n            if [ \"$(( $(date +%s) - start_time ))\" -ge \"$TIMEOUT\" ]; then echo \"Error: Timeout waiting for nodes.\" >&2; exit 1; fi\n            local ready_nodes=$(ray list nodes --format=json | python3 -c \"import sys, json; print(len(json.load(sys.stdin)))\")\n            if [ \"$ready_nodes\" -ge \"$NNODES\" ]; then break; fi\n            echo \"Waiting... 
($ready_nodes / $NNODES nodes ready)\"\n            sleep 5\n        done\n        echo \"All $NNODES nodes have joined.\"\n    fi\n\n    if [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"INFO [RANK 0]: Starting GSPO training command.\"\n        echo \"Command: ${TRAINING_CMD[*]}\"\n        eval \"${TRAINING_CMD[@]}\" \"$@\"\n        echo \"INFO [RANK 0]: Training finished.\"\n        sleep 30; ray stop --force >/dev/null 2>&1\n    elif [ \"$NNODES\" -gt 1 ]; then\n        local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n        echo \"INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address.\"\n        while ray health-check --address \"$head_node_address\" &>/dev/null; do sleep 15; done\n        echo \"INFO [RANK $NODE_RANK]: Head node down. Exiting.\"\n    fi\n\n    echo \"INFO: Script finished on rank $NODE_RANK.\"\n}\n\n# --- Script Entrypoint ---\nif [[ \"${BASH_SOURCE[0]}\" == \"${0}\" ]]; then\n    main \"$@\"\nfi\n"
  },
  {
    "path": "examples/dapo_trainer/run_qwen3-8b.sh",
    "content": "#!/usr/bin/env bash\n# ===================================================================================\n# ===                       USER CONFIGURATION SECTION                            ===\n# ===                    DAPO                                                                             ===\n# ===================================================================================\n\n# --- Experiment and Model Definition ---\nexport DATASET=DAPO-Math-17k\nexport ALG=grpo  # DAPO uses GRPO (Group Relative Policy Optimization) as the base algorithm\nexport MODEL_NAME=qwen3-8b\n\n# --- Path Definitions ---\nexport HOME={your_home_path}\nexport TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/dapo-math-17k.parquet\nexport TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet\nexport MODEL_PATH=$HOME/data/models/Qwen3-8B\n\n# Base output paths\nexport BASE_CKPT_PATH=ckpts\nexport BASE_TENSORBOARD_PATH=tensorboard\n\n# --- Key Training Hyperparameters ---\nexport TRAIN_BATCH_SIZE_PER_NODE=512\nexport PPO_MINI_BATCH_SIZE_PER_NODE=256\nexport INFER_MICRO_BATCH_SIZE=8\nexport TRAIN_MICRO_BATCH_SIZE=8\nexport OFFLOAD=False\nexport MAX_PROMPT_LENGTH=2048\nexport MAX_RESPONSE_LENGTH=8192\nexport ROLLOUT_GPU_MEMORY_UTILIZATION=0.6\nexport ROLLOUT_TP=2\nexport ROLLOUT_N=8\nexport SAVE_FREQ=30\nexport TEST_FREQ=10\nexport TOTAL_EPOCHS=30\nexport MAX_CKPT_KEEP=5\n\n# --- Multi-node (Multi-machine) distributed training environments ---\n\n# Uncomment the following line and set the correct network interface if needed for distributed backend\n# export GLOO_SOCKET_IFNAME=bond0  # Modify as needed\n\n# --- Distributed Training & Infrastructure ---\nexport N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}\nexport NNODES=${PET_NNODES:-1}\nexport NODE_RANK=${PET_NODE_RANK:-0}\nexport MASTER_ADDR=${MASTER_ADDR:-localhost}\n\n# --- Output Paths and Experiment Naming ---\nexport CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes\nexport PROJECT_NAME=dapo_${DATASET}_${ALG}\nexport EXPERIMENT_NAME=dapo_${MODEL_NAME}_${ALG}_${DATASET}_${NNODES}_experiment\nexport TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp\nexport SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp\n\n# --- Calculated Global Hyperparameters ---\nexport TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))\nexport PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))\nexport MAX_NUM_TOKEN_PER_GPU=$(($MAX_PROMPT_LENGTH + $MAX_RESPONSE_LENGTH))\n\n# --- DAPO-specific Hyperparameters ---\n# Filter groups: Enable dynamic sampling based on trajectory variance\nexport ENABLE_FILTER_GROUPS=True\nexport FILTER_GROUPS_METRIC=acc  # Metric used for filtering (accuracy)\nexport MAX_NUM_GEN_BATCHES=10    # Maximum generation batches before giving up\nexport GEN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * 3))  # Generation batch size (3x training batch per node)\n\n# KL divergence control\nexport USE_KL_IN_REWARD=False    # Whether to use KL penalty in reward\nexport KL_COEF=0.0               # KL coefficient for reward penalty\nexport USE_KL_LOSS=False         # Whether to use KL loss in actor training\nexport KL_LOSS_COEF=0.0          # KL loss coefficient\n\n# PPO clipping parameters for DAPO\nexport CLIP_RATIO_LOW=0.2        # Lower bound for PPO clipping\nexport CLIP_RATIO_HIGH=0.28      # Upper bound for PPO clipping\nexport LOSS_AGG_MODE=\"token-mean\" # Loss aggregation mode\n\n# Overlong sequence 
handling\nexport ENABLE_OVERLONG_BUFFER=True\nexport OVERLONG_BUFFER_LEN=512\nexport OVERLONG_PENALTY_FACTOR=1.0\n\n# Sampling parameters\nexport TEMPERATURE=1.0           # Sampling temperature\nexport TOP_P=1.0                 # Top-p sampling\nexport TOP_K=-1                  # Top-k sampling (-1 for disabled)\n\n# --- Define the Training Command and its Arguments ---\nTRAINING_CMD=(\n    python3 -m siirl.main_dag\n    algorithm.adv_estimator=\\$ALG\n    data.train_files=\\$TRAIN_DATA_PATH\n    data.val_files=\\$TEST_DATA_PATH\n    data.train_batch_size=\\$TRAIN_BATCH_SIZE\n    data.gen_batch_size=\\$GEN_BATCH_SIZE\n    data.max_prompt_length=\\$MAX_PROMPT_LENGTH\n    data.max_response_length=\\$MAX_RESPONSE_LENGTH\n    data.filter_overlong_prompts=True\n    data.truncation='left'\n    data.shuffle=False\n    data.prompt_key=prompt\n    actor_rollout_ref.model.path=\\$MODEL_PATH\n    actor_rollout_ref.actor.optim.lr=1e-6\n    actor_rollout_ref.actor.optim.lr_warmup_steps=10\n    actor_rollout_ref.actor.optim.weight_decay=0.1\n    actor_rollout_ref.model.use_remove_padding=True\n    actor_rollout_ref.model.use_fused_kernels=False\n    actor_rollout_ref.actor.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\\$TRAIN_MICRO_BATCH_SIZE\n    actor_rollout_ref.actor.use_kl_loss=\\$USE_KL_LOSS\n    actor_rollout_ref.actor.kl_loss_coef=\\$KL_LOSS_COEF\n    actor_rollout_ref.actor.grad_clip=1.0\n    actor_rollout_ref.actor.clip_ratio_low=\\$CLIP_RATIO_LOW\n    actor_rollout_ref.actor.clip_ratio_high=\\$CLIP_RATIO_HIGH\n    actor_rollout_ref.actor.clip_ratio_c=10.0\n    actor_rollout_ref.actor.entropy_coeff=0\n    actor_rollout_ref.actor.loss_agg_mode=\\$LOSS_AGG_MODE\n    actor_rollout_ref.model.enable_gradient_checkpointing=True\n    actor_rollout_ref.actor.fsdp_config.param_offload=\\$OFFLOAD\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=\\$OFFLOAD\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1\n    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\\$INFER_MICRO_BATCH_SIZE\n    actor_rollout_ref.rollout.tensor_model_parallel_size=\\$ROLLOUT_TP\n    actor_rollout_ref.rollout.name=vllm\n    actor_rollout_ref.rollout.prompt_length=\\$MAX_PROMPT_LENGTH  \n    actor_rollout_ref.rollout.response_length=\\$MAX_RESPONSE_LENGTH\n    actor_rollout_ref.rollout.gpu_memory_utilization=\\$ROLLOUT_GPU_MEMORY_UTILIZATION\n    actor_rollout_ref.rollout.enable_chunked_prefill=True\n    actor_rollout_ref.rollout.enforce_eager=False\n    actor_rollout_ref.rollout.free_cache_engine=False\n    actor_rollout_ref.rollout.max_num_batched_tokens=$((MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH))\n    actor_rollout_ref.rollout.max_model_len=$((MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH))\n    actor_rollout_ref.rollout.n=\\$ROLLOUT_N\n    actor_rollout_ref.rollout.temperature=\\$TEMPERATURE\n    actor_rollout_ref.rollout.top_p=\\$TOP_P\n    actor_rollout_ref.rollout.top_k=\\$TOP_K\n    actor_rollout_ref.rollout.val_kwargs.temperature=\\$TEMPERATURE\n    actor_rollout_ref.rollout.val_kwargs.top_p=\\$TOP_P\n    actor_rollout_ref.rollout.val_kwargs.top_k=\\$TOP_K\n    actor_rollout_ref.rollout.val_kwargs.do_sample=True\n    actor_rollout_ref.rollout.val_kwargs.n=1\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\\$INFER_MICRO_BATCH_SIZE\n    actor_rollout_ref.ref.fsdp_config.param_offload=\\$OFFLOAD\n    algorithm.workflow_type=dapo\n    
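# DAPO dynamic sampling: filter_groups drops prompt groups whose ROLLOUT_N samples show no variance on the acc metric, and keeps regenerating batches of GEN_BATCH_SIZE prompts (up to MAX_NUM_GEN_BATCHES) until the training batch is filled.\n    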
algorithm.use_kl_in_reward=\\$USE_KL_IN_REWARD\n    algorithm.kl_ctrl.kl_coef=\\$KL_COEF\n    algorithm.filter_groups.enable=\\$ENABLE_FILTER_GROUPS\n    algorithm.filter_groups.metric=\\$FILTER_GROUPS_METRIC\n    algorithm.filter_groups.max_num_gen_batches=\\$MAX_NUM_GEN_BATCHES\n    reward_model.reward_manager=dapo\n    reward_model.overlong_buffer.enable=\\$ENABLE_OVERLONG_BUFFER\n    reward_model.overlong_buffer.len=\\$OVERLONG_BUFFER_LEN\n    reward_model.overlong_buffer.penalty_factor=\\$OVERLONG_PENALTY_FACTOR\n    trainer.critic_warmup=0\n    trainer.logger=[\"console\",\"tensorboard\"]\n    trainer.project_name=\\$PROJECT_NAME\n    trainer.experiment_name=\\$EXPERIMENT_NAME\n    trainer.n_gpus_per_node=\\$N_GPUS_PER_NODE\n    trainer.nnodes=\\$NNODES\n    trainer.save_freq=\\$SAVE_FREQ\n    trainer.test_freq=\\$TEST_FREQ\n    trainer.total_epochs=\\$TOTAL_EPOCHS\n    trainer.resume_mode=auto\n    trainer.max_actor_ckpt_to_keep=\\$MAX_CKPT_KEEP\n    trainer.default_local_dir=\\$CKPT_PATH\n    trainer.val_before_train=True\n)\n\n# ===================================================================================\n# ===                  MAIN EXECUTION LOGIC & INFRASTRUCTURE                      ===\n# ===================================================================================\n\n# --- Boilerplate Setup ---\nset -e\nset -o pipefail\nset -x\n\n# --- Infrastructure & Boilerplate Functions ---\nstart_ray_cluster() {\n    local RAY_HEAD_WAIT_TIMEOUT=600\n    export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200\n    export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120\n\n    local ray_start_common_opts=(\n        --num-gpus \"$N_GPUS_PER_NODE\"\n        --object-store-memory 100000000000\n        --memory 100000000000\n    )\n\n    if [ \"$NNODES\" -gt 1 ]; then\n        if [ \"$NODE_RANK\" = \"0\" ]; then\n            echo \"INFO: Starting Ray head node on $(hostname)...\"\n            export RAY_ADDRESS=\"$RAY_MASTER_ADDR:$RAY_MASTER_PORT\"\n            ray start --head --port=\"$RAY_MASTER_PORT\" --dashboard-port=\"$RAY_DASHBOARD_PORT\" \"${ray_start_common_opts[@]}\" --system-config='{\"gcs_server_request_timeout_seconds\": 60, \"gcs_rpc_server_reconnect_timeout_s\": 60}'\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$RAY_ADDRESS\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head node. Exiting.\" >&2; ray stop --force; exit 1; fi\n                echo \"Head node not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head node is healthy.\"\n        else\n            local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n            echo \"INFO: Worker node $(hostname) waiting for head at $head_node_address...\"\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$head_node_address\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head. Exiting.\" >&2; exit 1; fi\n                echo \"Head not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head is healthy. 
Worker starting...\"\n            ray start --address=\"$head_node_address\" \"${ray_start_common_opts[@]}\"\n        fi\n    else\n        echo \"INFO: Starting Ray in single-node mode...\"\n        ray start --head \"${ray_start_common_opts[@]}\"\n    fi\n}\n\n# --- Main Execution Function ---\nmain() {\n    local timestamp=$(date +\"%Y%m%d_%H%M%S\")\n    ray stop --force\n\n\n\n    export VLLM_USE_V1=1\n    export GLOO_SOCKET_TIMEOUT=600\n    export GLOO_TCP_TIMEOUT=600\n    export GLOO_LOG_LEVEL=DEBUG\n    export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}\n    export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}\n    export RAY_MASTER_ADDR=$MASTER_ADDR\n\n    start_ray_cluster\n\n    if [ \"$NNODES\" -gt 1 ] && [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"Waiting for all $NNODES nodes to join...\"\n        local TIMEOUT=600; local start_time=$(date +%s)\n        while true; do\n            if [ \"$(( $(date +%s) - start_time ))\" -ge \"$TIMEOUT\" ]; then echo \"Error: Timeout waiting for nodes.\" >&2; exit 1; fi\n            local ready_nodes=$(ray list nodes --format=json | python3 -c \"import sys, json; print(len(json.load(sys.stdin)))\")\n            if [ \"$ready_nodes\" -ge \"$NNODES\" ]; then break; fi\n            echo \"Waiting... ($ready_nodes / $NNODES nodes ready)\"\n            sleep 5\n        done\n        echo \"All $NNODES nodes have joined.\"\n    fi\n\n    if [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"INFO [RANK 0]: Starting main training command.\"\n        eval \"${TRAINING_CMD[@]}\" \"$@\"\n        echo \"INFO [RANK 0]: Training finished.\"\n        sleep 30; ray stop --force >/dev/null 2>&1\n    elif [ \"$NNODES\" -gt 1 ]; then\n        local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n        echo \"INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address.\"\n        while ray health-check --address \"$head_node_address\" &>/dev/null; do sleep 15; done\n        echo \"INFO [RANK $NODE_RANK]: Head node down. Exiting.\"\n    fi\n\n    echo \"INFO: Script finished on rank $NODE_RANK.\"\n}\n\n# --- Script Entrypoint ---\nmain \"$@\"\n"
  },
  {
    "path": "examples/data_preprocess/deepscaler.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nPreprocess the DeepScaleR dataset to parquet format\n\"\"\"\n\nimport argparse\nimport json\nimport os\n\nimport datasets\n\nfrom siirl.utils.extras.hdfs_io import copy, makedirs\n\n\ndef load_json(file_path):\n    with open(file_path, \"r\", encoding=\"utf-8\") as file:\n        dataset = json.load(file)\n        return dataset\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--local_dir\", default=\"~/data/deepscaler\")\n    parser.add_argument(\"--source_dir\", default=None)\n    parser.add_argument(\"--hdfs_dir\", default=None)\n    parser.add_argument(\"--seed\", default=15)\n\n    args = parser.parse_args()\n\n    data_source = \"agentica-org/DeepScaleR-Preview-Dataset\"\n    instruction_following = \"Let's think step by step and output the final within \\\\boxed{}.\"\n\n    if args.source_dir == None:\n        args.source_dir = data_source\n    raw_dataset = datasets.load_dataset(\"json\", data_files=args.source_dir)\n    full_dataset = raw_dataset[\"train\"]\n    train_test_split_dataset = full_dataset.train_test_split(test_size=0.1, seed=args.seed)\n\n    train_dataset = train_test_split_dataset[\"train\"]\n    test_dataset = train_test_split_dataset[\"test\"]\n\n    def make_map_fn(split_name):\n        def process_fn(example, idx):\n            question_raw = example.pop(\"problem\")\n            answer_raw = example.pop(\"answer\")\n\n            question = question_raw + \" \" + instruction_following\n            solution = example.pop(\"solution\")\n            data = {\n                \"data_source\": data_source,\n                \"prompt\": [\n                    {\n                        \"role\": \"user\",\n                        \"content\": question,\n                    }\n                ],\n                \"ability\": \"math\",\n                \"reward_model\": {\"style\": \"rule\", \"ground_truth\": answer_raw},\n                \"extra_info\": {\n                    \"split\": split_name,\n                    \"index\": idx,\n                    \"answer\": solution,\n                    \"question\": question_raw,\n                },\n            }\n\n            return data\n\n        return process_fn\n\n    processed_train_dataset = train_dataset.map(function=make_map_fn(\"train\"), with_indices=True)\n    processed_test_dataset = test_dataset.map(function=make_map_fn(\"test\"), with_indices=True)\n\n    local_dir = args.local_dir\n    hdfs_dir = args.hdfs_dir\n\n    processed_train_dataset.to_parquet(os.path.join(local_dir, \"train.parquet\"))\n    processed_test_dataset.to_parquet(os.path.join(local_dir, \"test.parquet\"))\n\n    if hdfs_dir is not None:\n        makedirs(hdfs_dir)\n\n        copy(src=local_dir, dst=hdfs_dir)\n"
  },
  {
    "path": "examples/data_preprocess/geo3k.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nPreprocess the Geometry3k dataset to parquet format\n\"\"\"\n\nimport argparse\nimport os\n\nimport datasets\n\nfrom siirl.utils.extras.hdfs_io import copy, makedirs\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--local_dir\", default=\"~/data/geo3k\")\n    parser.add_argument(\"--hdfs_dir\", default=None)\n\n    args = parser.parse_args()\n\n    data_source = \"hiyouga/geometry3k\"\n\n    dataset = datasets.load_dataset(data_source)\n\n    train_dataset = dataset[\"train\"]\n    test_dataset = dataset[\"test\"]\n\n    instruction_following = (\n        r\"You FIRST think about the reasoning process as an internal monologue and then provide the final answer. \"\n        r\"The reasoning process MUST BE enclosed within <think> </think> tags. The final answer MUST BE put in \\boxed{}.\"\n    )\n\n    # add a row to each data item that represents a unique id\n    def make_map_fn(split):\n        def process_fn(example, idx):\n            problem = example.pop(\"problem\")\n            prompt = problem + \" \" + instruction_following\n            answer = example.pop(\"answer\")\n            images = example.pop(\"images\")\n\n            data = {\n                \"data_source\": data_source,\n                \"prompt\": [\n                    {\n                        \"role\": \"user\",\n                        \"content\": prompt,\n                    }\n                ],\n                \"images\": images,\n                \"ability\": \"math\",\n                \"reward_model\": {\"style\": \"rule\", \"ground_truth\": answer},\n                \"extra_info\": {\n                    \"split\": split,\n                    \"index\": idx,\n                    \"answer\": answer,\n                    \"question\": problem,\n                },\n            }\n            return data\n\n        return process_fn\n\n    train_dataset = train_dataset.map(function=make_map_fn(\"train\"), with_indices=True, num_proc=8)\n    test_dataset = test_dataset.map(function=make_map_fn(\"test\"), with_indices=True, num_proc=8)\n\n    local_dir = args.local_dir\n    hdfs_dir = args.hdfs_dir\n\n    train_dataset.to_parquet(os.path.join(local_dir, \"train.parquet\"))\n    test_dataset.to_parquet(os.path.join(local_dir, \"test.parquet\"))\n\n    if hdfs_dir is not None:\n        makedirs(hdfs_dir)\n        copy(src=local_dir, dst=hdfs_dir)\n"
  },
  {
    "path": "examples/data_preprocess/gsm8k.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nPreprocess the GSM8k dataset to parquet format\n\"\"\"\n\nimport argparse\nimport os\nimport re\n\nimport datasets\n\nfrom siirl.utils.extras.hdfs_io import copy, makedirs\n\n\ndef extract_solution(solution_str):\n    solution = re.search(\"#### (\\\\-?[0-9\\\\.\\\\,]+)\", solution_str)\n    assert solution is not None\n    final_solution = solution.group(0)\n    final_solution = final_solution.split(\"#### \")[1].replace(\",\", \"\")\n    return final_solution\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--local_dir\", default=\"~/data/gsm8k\")\n    parser.add_argument(\"--hdfs_dir\", default=None)\n\n    args = parser.parse_args()\n\n    data_source = \"openai/gsm8k\"\n\n    dataset = datasets.load_dataset(data_source, \"main\")\n\n    train_dataset = dataset[\"train\"]\n    test_dataset = dataset[\"test\"]\n\n    instruction_following = 'Let\\'s think step by step and output the final answer after \"####\".'\n\n    # add a row to each data item that represents a unique id\n    def make_map_fn(split):\n        def process_fn(example, idx):\n            question_raw = example.pop(\"question\")\n\n            question = question_raw + \" \" + instruction_following\n\n            answer_raw = example.pop(\"answer\")\n            solution = extract_solution(answer_raw)\n            data = {\n                \"data_source\": data_source,\n                \"prompt\": [\n                    {\n                        \"role\": \"user\",\n                        \"content\": question,\n                    }\n                ],\n                \"ability\": \"math\",\n                \"reward_model\": {\"style\": \"rule\", \"ground_truth\": solution},\n                \"extra_info\": {\n                    \"split\": split,\n                    \"index\": idx,\n                    \"answer\": answer_raw,\n                    \"question\": question_raw,\n                },\n            }\n            return data\n\n        return process_fn\n\n    train_dataset = train_dataset.map(function=make_map_fn(\"train\"), with_indices=True)\n    test_dataset = test_dataset.map(function=make_map_fn(\"test\"), with_indices=True)\n\n    local_dir = args.local_dir\n    hdfs_dir = args.hdfs_dir\n\n    train_dataset.to_parquet(os.path.join(local_dir, \"train.parquet\"))\n    test_dataset.to_parquet(os.path.join(local_dir, \"test.parquet\"))\n\n    if hdfs_dir is not None:\n        makedirs(hdfs_dir)\n\n        copy(src=local_dir, dst=hdfs_dir)\n"
  },
  {
    "path": "examples/data_preprocess/math_dataset.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nPreprocess the MATH-lighteval dataset to parquet format\n\"\"\"\n\nimport argparse\nimport os\n\nimport datasets\n\nfrom siirl.utils.extras.hdfs_io import copy, makedirs\nfrom siirl.utils.reward_score.math import last_boxed_only_string, remove_boxed\n\n\ndef extract_solution(solution_str):\n    return remove_boxed(last_boxed_only_string(solution_str))\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--local_dir\", default=\"~/data/math\")\n    parser.add_argument(\"--hdfs_dir\", default=None)\n\n    args = parser.parse_args()\n\n    # 'lighteval/MATH' is no longer available on huggingface.\n    # Use mirror repo: DigitalLearningGmbH/MATH-lighteval\n    data_source = \"DigitalLearningGmbH/MATH-lighteval\"\n    print(f\"Loading the {data_source} dataset from huggingface...\", flush=True)\n    dataset = datasets.load_dataset(data_source, trust_remote_code=True)\n\n    train_dataset = dataset[\"train\"]\n    test_dataset = dataset[\"test\"]\n\n    instruction_following = \"Let's think step by step and output the final answer within \\\\boxed{}.\"\n\n    # add a row to each data item that represents a unique id\n    def make_map_fn(split):\n        def process_fn(example, idx):\n            question = example.pop(\"problem\")\n\n            question = question + \" \" + instruction_following\n\n            answer = example.pop(\"solution\")\n            solution = extract_solution(answer)\n            data = {\n                \"data_source\": data_source,\n                \"prompt\": [{\"role\": \"user\", \"content\": question}],\n                \"ability\": \"math\",\n                \"reward_model\": {\"style\": \"rule\", \"ground_truth\": solution},\n                \"extra_info\": {\"split\": split, \"index\": idx},\n            }\n            return data\n\n        return process_fn\n\n    train_dataset = train_dataset.map(function=make_map_fn(\"train\"), with_indices=True)\n    test_dataset = test_dataset.map(function=make_map_fn(\"test\"), with_indices=True)\n\n    local_dir = args.local_dir\n    hdfs_dir = args.hdfs_dir\n\n    train_dataset.to_parquet(os.path.join(local_dir, \"train.parquet\"))\n    test_dataset.to_parquet(os.path.join(local_dir, \"test.parquet\"))\n\n    if hdfs_dir is not None:\n        makedirs(hdfs_dir)\n\n        copy(src=local_dir, dst=hdfs_dir)\n"
  },
  {
    "path": "examples/data_preprocess/mm_eureka.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nPreprocess the MM Eureka dataset to parquet format\n\"\"\"\n\nimport argparse\nimport os\n\nfrom datasets import load_dataset\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--jsonl_file\", type=str)\n    parser.add_argument(\"--output_dir\", type=str, default=\"~/data/mm_eureka/\")\n    parser.add_argument(\"--dataset_name\", type=str, default=\"mm_eureka\")\n    parser.add_argument(\"--nproc\", type=int, default=16)\n    parser.add_argument(\"--test_split\", type=int, default=5, help=\"split percentage of test set\")\n\n    args = parser.parse_args()\n\n    dataset_name = args.dataset_name\n    nproc = args.nproc\n\n    instruct_prompt = \"You should first thinks about the reasoning process in the mind and then provides the user with the answer.\"\n    instruction_following = (\n        r\"You should first thinks about the reasoning process in the mind and then provides the user with the answer. \"\n        r\"Your answer must be in latex format and wrapped in $...$.The reasoning process and answer are enclosed within <think> </think> \"\n        r\"and <answer> </answer> tags, respectively, i.e., <think> Since $1+1=2$, so the answer is $2$. 
</think><answer> $2$ </answer>, \"\n        r\"which means your output should start with <think> and end with </answer>.\"\n    )\n\n    test_split = args.test_split\n    assert test_split > 0 and test_split < 100\n\n    train_dataset = load_dataset(\"json\", data_files=args.jsonl_file, split=f\"train[:{100 - test_split}%]\")\n    test_dataset = load_dataset(\"json\", data_files=args.jsonl_file, split=f\"train[-{test_split}%:]\")\n\n    # add a row to each data item that represents a unique id\n    def make_map_fn(split):\n        def process_fn(example, idx):\n            id = example.pop(\"id\")\n            conversations = example.pop(\"conversations\")\n            answer = example.pop(\"answer\")\n            image_urls = example.pop(\"image_urls\")\n\n            prompts = []\n            for conv in conversations:\n                if conv[\"role\"] == \"user\":\n                    if instruct_prompt not in conv[\"content\"]:\n                        conv[\"content\"] = instruction_following + \" \" + conv[\"content\"]\n                    prompts.append(conv)\n                # skip other roles such as \"assistant\", \"system\", etc.\n\n            images = []\n            for image_url in image_urls:\n                with open(image_url, \"rb\") as f:\n                    images.append({\"path\": image_url, \"bytes\": f.read()})\n\n            data = {\n                \"data_source\": dataset_name,\n                \"prompt\": prompts,\n                \"images\": images,\n                \"ability\": \"math\",\n                \"reward_model\": {\"style\": \"rule\", \"ground_truth\": answer},\n                \"extra_info\": {\n                    \"id\": id,\n                    \"split\": split,\n                    \"index\": idx,\n                    \"answer\": answer,\n                },\n            }\n            return data\n\n        return process_fn\n\n    train_dataset = train_dataset.map(function=make_map_fn(\"train\"), with_indices=True, num_proc=nproc)\n    test_dataset = test_dataset.map(function=make_map_fn(\"test\"), with_indices=True, num_proc=nproc)\n\n    train_file = os.path.join(args.output_dir, \"train.parquet\")\n    test_file = os.path.join(args.output_dir, \"test.parquet\")\n    train_dataset.to_parquet(train_file)\n    print(f\"Write Done. train file: {train_file}\")\n    test_dataset.to_parquet(test_file)\n    print(f\"Write Done. test file: {test_file}\")\n"
  },
  {
    "path": "examples/embodied_srpo_trainer/run_openvla_oft_libero_goal.sh",
    "content": "#!/usr/bin/env bash\n# ===================================================================================\n# ===    Embodied AI SRPO Training with OpenVLA-OFT on LIBERO-GOAL                ===\n# ===================================================================================\n# \n\nset -e\n\n# --- Environment Setup (Critical for siiRL) ---\nexport SIIRL_DIR=\"${SIIRL_DIR:your_siirl_path}\"\nexport PYTHONPATH=\"$SIIRL_DIR:/root/LIBERO/:your_vjepa2_path:$PYTHONPATH\"\n\n# --- Experiment and Model Definition ---\nexport DATASET=libero_goal\nexport ALG=srpo\nexport MODEL_NAME=openvla-oft-7b\nexport MODEL_TYPE=openvla-oft\n\n# --- Path Definitions (USER PROVIDED) ---\nexport HOME_PATH=${HOME_PATH:your_home_path}\nexport TRAIN_DATA_PATH=$HOME_PATH/datasets/vla-oft/libero/$DATASET/train.parquet\nexport TEST_DATA_PATH=$HOME_PATH/datasets/vla-oft/libero/$DATASET/test.parquet\nexport MODEL_PATH=$HOME_PATH/models/Sylvest/OpenVLA-AC-PD-1traj-libero-goal\nexport VJEPA_MODEL_PATH=$HOME_PATH/models/vjepa2/vitg-384.pt\n\n# Base output paths\nexport BASE_CKPT_PATH=ckpts\nexport BASE_TENSORBOARD_PATH=tensorboard\n\n# --- Embodied AI Specific Parameters ---\nexport ACTION_TOKEN_LEN=7        # 7 dimensions: xyz (3), quaternion (3), gripper (1)\nexport ACTION_CHUNKS_LEN=8       # OpenVLA-OFT uses 8-step action chunks\nexport NUM_ENVS=16               # actor_rollout_ref.embodied.env.num_envs\nexport MAX_EPISODE_STEPS=512     # actor_rollout_ref.embodied.env.max_steps\n\n# --- Data and Sampling Parameters ---\nexport VAL_BATCH_SIZE=496                      # Validation batch size\nexport MAX_PROMPT_LENGTH=256                   \nexport MAX_RESPONSE_LENGTH=128                 \n\n# --- Embodied Sampling Parameters ---\nexport FILTER_ACCURACY=True                    # Enable accuracy-based filtering\nexport ACCURACY_LOWER_BOUND=0.1                # Only keep prompts with success rate >= 0.1\nexport ACCURACY_UPPER_BOUND=0.9                # Only keep prompts with success rate <= 0.9\nexport FILTER_TRUNCATED=False                  # Filter truncated episodes (uses env.max_steps)\nexport OVERSAMPLE_FACTOR=1                     # Oversample factor for filtering\n\n# --- Training Hyperparameters ---\nexport TRAIN_BATCH_SIZE=64       # data.train_batch_size\nexport PPO_MINI_BATCH_SIZE=4    # actor_rollout_ref.actor.ppo_mini_batch_size\n                                # Note: actual ppo_mini_batch_size = PPO_MINI_BATCH_SIZE * ROLLOUT_N_SAMPLES\nexport ROLLOUT_N_SAMPLES=8       # REUSED: Number of samples per prompt\nexport PPO_EPOCHS=1              # actor_rollout_ref.actor.ppo_epochs\n\n# Algorithm parameters\nexport LEARNING_RATE=5e-6        \nexport WEIGHT_DECAY=0.0          # actor_rollout_ref.actor.optim.weight_decay\nexport CLIP_RATIO_HIGH=0.28      # actor_rollout_ref.actor.clip_ratio_high\nexport CLIP_RATIO_LOW=0.2        # actor_rollout_ref.actor.clip_ratio_low\nexport ENTROPY_COEFF=0.0         \nexport TEMPERATURE=1.6          \nexport GAMMA=1.0                 \nexport LAM=1.0                   \nexport GRAD_CLIP=1.0            \n\n# --- Image/Video Processing ---\nexport IMG_SIZE=384              # actor_rollout_ref.embodied.img_size\nexport ENABLE_FP16=True          # actor_rollout_ref.embodied.enable_fp16\nexport EMBEDDING_MODEL_OFFLOAD=False  # actor_rollout_ref.embodied.embedding_model_offload\nexport CENTER_CROP=True          # actor_rollout_ref.embodied.center_crop\nexport NUM_IMAGES_IN_INPUT=1     \nexport NUM_STEPS_WAIT=10           # Environment stabilization steps\n\n# --- 
Trainer Configuration ---\nexport SAVE_FREQ=5              \nexport TEST_FREQ=5              \nexport TOTAL_EPOCHS=1000         # trainer.total_epochs\nexport MAX_CKPT_KEEP=5           # trainer.max_actor_ckpt_to_keep\nexport VAL_BEFORE_TRAIN=True     # trainer.val_before_train\n\n# --- Multi-node distributed training ---\nexport N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}\nexport NNODES=${PET_NNODES:-1}\nexport NODE_RANK=${PET_NODE_RANK:-0}\nexport MASTER_ADDR=${MASTER_ADDR:-localhost}\nexport MASTER_PORT=${MASTER_PORT:-29500}\n\n# --- Environment Variables ---\nexport MUJOCO_GL=egl\nexport PYOPENGL_PLATFORM=egl\nexport GLOO_SOCKET_TIMEOUT=600\n\n# --- Output Paths and Experiment Naming ---\ntimestamp=$(date +%Y%m%d_%H%M%S)\nexport CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_${NNODES}nodes\nexport PROJECT_NAME=siirl_embodied_${DATASET}\nexport EXPERIMENT_NAME=openvla_oft_srpo_fsdp\nexport TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}/${timestamp}\nexport SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_${timestamp}\n\n# --- Define the Training Command ---\nTRAINING_CMD=(\n    \n    python3 -m siirl.main_dag\n    --config-name=embodied_grpo_trainer\n    \n    # Data configuration\n    data.train_files=\\$TRAIN_DATA_PATH\n    data.val_files=\\$TEST_DATA_PATH\n    data.train_batch_size=\\$TRAIN_BATCH_SIZE\n    data.val_batch_size=\\$VAL_BATCH_SIZE\n    data.max_prompt_length=\\$MAX_PROMPT_LENGTH\n    data.max_response_length=\\$MAX_RESPONSE_LENGTH\n    data.dataset_type=embodied  \n\n    # Reward\n    reward_model.reward_manager=embodied\n    reward_model.reward_kwargs.action_token_len=7\n    reward_model.reward_kwargs.reward_coef=5.0\n\n    # Algorithm configuration\n    algorithm.workflow_type=embodied\n    algorithm.adv_estimator=grpo\n    algorithm.gamma=\\$GAMMA\n    algorithm.lam=\\$LAM\n    algorithm.norm_adv_by_std_in_grpo=True\n    \n    # Embodied sampling configuration (aligned with DAPO architecture)\n    algorithm.filter_groups.enable=True\n    algorithm.embodied_sampling.filter_accuracy=\\$FILTER_ACCURACY\n    algorithm.embodied_sampling.accuracy_lower_bound=\\$ACCURACY_LOWER_BOUND\n    algorithm.embodied_sampling.accuracy_upper_bound=\\$ACCURACY_UPPER_BOUND\n    algorithm.embodied_sampling.filter_truncated=\\$FILTER_TRUNCATED\n    algorithm.embodied_sampling.oversample_factor=\\$OVERSAMPLE_FACTOR\n    \n    # Model configuration\n    actor_rollout_ref.model.path=\\$MODEL_PATH\n    actor_rollout_ref.model.enable_gradient_checkpointing=True\n    actor_rollout_ref.model.model_type=embodied\n    actor_rollout_ref.model.trust_remote_code=True\n\n    # Actor configuration\n    actor_rollout_ref.actor.optim.lr=\\$LEARNING_RATE\n    actor_rollout_ref.actor.optim.weight_decay=\\$WEIGHT_DECAY\n    actor_rollout_ref.actor.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n    actor_rollout_ref.actor.ppo_epochs=\\$PPO_EPOCHS\n    actor_rollout_ref.actor.grad_clip=\\$GRAD_CLIP\n    actor_rollout_ref.actor.clip_ratio_high=\\$CLIP_RATIO_HIGH\n    actor_rollout_ref.actor.clip_ratio_low=\\$CLIP_RATIO_LOW\n    actor_rollout_ref.actor.entropy_coeff=\\$ENTROPY_COEFF\n    actor_rollout_ref.actor.shuffle=True\n    \n    # Actor FSDP configuration\n    actor_rollout_ref.actor.fsdp_config.param_offload=False\n    actor_rollout_ref.actor.fsdp_config.grad_offload=False\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False\n    \n    # Rollout configuration\n    actor_rollout_ref.rollout.name=hf\n    actor_rollout_ref.rollout.n=\\$ROLLOUT_N_SAMPLES\n    
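# Note: rollouts use the HF generation backend (rollout.name=hf above); each task prompt is executed ROLLOUT_N_SAMPLES times in the LIBERO environments so the GRPO estimator can compute group-relative advantages.\n    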
actor_rollout_ref.rollout.temperature=\\$TEMPERATURE\n    actor_rollout_ref.rollout.do_sample=True\n    actor_rollout_ref.rollout.prompt_length=\\$MAX_PROMPT_LENGTH  \n    actor_rollout_ref.rollout.response_length=512\n    \n    # Embodied AI specific configuration\n    actor_rollout_ref.embodied.embodied_type=\\$MODEL_TYPE\n    actor_rollout_ref.embodied.action_token_len=\\$ACTION_TOKEN_LEN\n    actor_rollout_ref.embodied.action_chunks_len=\\$ACTION_CHUNKS_LEN\n    actor_rollout_ref.embodied.video_embedding_model_path=\\$VJEPA_MODEL_PATH\n    actor_rollout_ref.embodied.embedding_img_size=\\$IMG_SIZE\n    actor_rollout_ref.embodied.embedding_enable_fp16=\\$ENABLE_FP16\n    actor_rollout_ref.embodied.embedding_model_offload=\\$EMBEDDING_MODEL_OFFLOAD\n    actor_rollout_ref.embodied.center_crop=\\$CENTER_CROP\n    actor_rollout_ref.embodied.num_images_in_input=\\$NUM_IMAGES_IN_INPUT\n    actor_rollout_ref.embodied.unnorm_key=\\$DATASET\n    \n    # Environment configuration\n    actor_rollout_ref.embodied.env.env_type=libero\n    actor_rollout_ref.embodied.env.env_name=\\$DATASET\n    actor_rollout_ref.embodied.env.num_envs=\\$NUM_ENVS\n    actor_rollout_ref.embodied.env.max_steps=\\$MAX_EPISODE_STEPS\n    actor_rollout_ref.embodied.env.num_steps_wait=\\$NUM_STEPS_WAIT\n    actor_rollout_ref.embodied.env.num_trials_per_task=50\n    actor_rollout_ref.embodied.env.model_family=openvla\n    \n    # Critic configuration (SRPO doesn't use critic)\n    critic.use_critic_model=False\n    \n    # Trainer configuration\n    trainer.total_epochs=\\$TOTAL_EPOCHS\n    trainer.save_freq=\\$SAVE_FREQ\n    trainer.test_freq=\\$TEST_FREQ\n    trainer.max_actor_ckpt_to_keep=\\$MAX_CKPT_KEEP\n    trainer.logger=['console','tensorboard']\n    trainer.project_name=\\$PROJECT_NAME\n    trainer.experiment_name=\\$EXPERIMENT_NAME\n    trainer.nnodes=\\$NNODES\n    trainer.n_gpus_per_node=\\$N_GPUS_PER_NODE\n    trainer.default_local_dir=\\$CKPT_PATH\n    trainer.resume_mode=auto\n    trainer.val_before_train=\\$VAL_BEFORE_TRAIN\n)\n\n# ===================================================================================\n# ===                          EXECUTION LOGIC                                    ===\n# ===================================================================================\n\n# --- Boilerplate Setup ---\nset -e\nset -o pipefail\nset -x\n\n# --- Infrastructure & Boilerplate Functions ---\nstart_ray_cluster() {\n    local RAY_HEAD_WAIT_TIMEOUT=600\n    export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200\n    export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120\n\n    local ray_start_common_opts=(\n        --num-gpus \"$N_GPUS_PER_NODE\"\n        --object-store-memory 100000000000\n        --memory 100000000000\n    )\n\n    if [ \"$NNODES\" -gt 1 ]; then\n        if [ \"$NODE_RANK\" = \"0\" ]; then\n            echo \"INFO: Starting Ray head node on $(hostname)...\"\n            export RAY_ADDRESS=\"$RAY_MASTER_ADDR:$RAY_MASTER_PORT\"\n            ray start --head --port=\"$RAY_MASTER_PORT\" --dashboard-port=\"$RAY_DASHBOARD_PORT\" \"${ray_start_common_opts[@]}\" --system-config='{\"gcs_server_request_timeout_seconds\": 60, \"gcs_rpc_server_reconnect_timeout_s\": 60}'\n            local start_time=$(date +%s)\n            while ! 
ray health-check --address \"$RAY_ADDRESS\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head node. Exiting.\" >&2; ray stop --force; exit 1; fi\n                echo \"Head node not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head node is healthy.\"\n        else\n            local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n            echo \"INFO: Worker node $(hostname) waiting for head at $head_node_address...\"\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$head_node_address\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head. Exiting.\" >&2; exit 1; fi\n                echo \"Head not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head is healthy. Worker starting...\"\n            ray start --address=\"$head_node_address\" \"${ray_start_common_opts[@]}\"\n        fi\n    else\n        echo \"INFO: Starting Ray in single-node mode...\"\n        ray start --head \"${ray_start_common_opts[@]}\"\n    fi\n}\n\n# --- Main Execution Function ---\nmain() {\n    local timestamp=$(date +\"%Y%m%d_%H%M%S\")\n    ray stop --force\n\n    export VLLM_USE_V1=1\n    export GLOO_SOCKET_TIMEOUT=600\n    export GLOO_TCP_TIMEOUT=600\n    export GLOO_LOG_LEVEL=DEBUG\n    export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}\n    export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}\n    export RAY_MASTER_ADDR=$MASTER_ADDR\n    \n    start_ray_cluster\n\n    if [ \"$NNODES\" -gt 1 ] && [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"Waiting for all $NNODES nodes to join...\"\n        local TIMEOUT=600; local start_time=$(date +%s)\n        while true; do\n            if [ \"$(( $(date +%s) - start_time ))\" -ge \"$TIMEOUT\" ]; then echo \"Error: Timeout waiting for nodes.\" >&2; exit 1; fi\n            local ready_nodes=$(ray list nodes --format=json | python3 -c \"import sys, json; print(len(json.load(sys.stdin)))\")\n            if [ \"$ready_nodes\" -ge \"$NNODES\" ]; then break; fi\n            echo \"Waiting... ($ready_nodes / $NNODES nodes ready)\"\n            sleep 5\n        done\n        echo \"All $NNODES nodes have joined.\"\n    fi\n\n    if [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"INFO [RANK 0]: Starting main training command.\"\n        eval \"${TRAINING_CMD[@]}\" \"$@\"\n        echo \"INFO [RANK 0]: Training finished.\"\n        sleep 30; ray stop --force >/dev/null 2>&1\n    elif [ \"$NNODES\" -gt 1 ]; then\n        local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n        echo \"INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address.\"\n        while ray health-check --address \"$head_node_address\" &>/dev/null; do sleep 15; done\n        echo \"INFO [RANK $NODE_RANK]: Head node down. Exiting.\"\n    fi\n\n    echo \"INFO: Script finished on rank $NODE_RANK.\"\n}\n\n# --- Script Entrypoint ---\nmain \"$@\"\n"
  },
  {
    "path": "examples/embodied_srpo_trainer/run_openvla_oft_libero_long.sh",
    "content": "#!/usr/bin/env bash\n# ===================================================================================\n# ===    Embodied AI SRPO Training with OpenVLA-OFT on LIBERO-10               ===\n# ===================================================================================\n# \n\nset -e\n\n# --- Environment Setup (Critical for siiRL) ---\nexport SIIRL_DIR=\"${SIIRL_DIR:your_siirl_path}\"\nexport PYTHONPATH=\"$SIIRL_DIR:/root/LIBERO/:your_vjepa2_path:$PYTHONPATH\"\n\n# --- Experiment and Model Definition ---\nexport DATASET=libero_10\nexport ALG=srpo\nexport MODEL_NAME=openvla-oft-7b\nexport MODEL_TYPE=openvla-oft\n\n# --- Path Definitions (USER PROVIDED) ---\nexport HOME_PATH=${HOME_PATH:your_home_path}\nexport TRAIN_DATA_PATH=$HOME_PATH/datasets/vla-oft/libero/$DATASET/train.parquet\nexport TEST_DATA_PATH=$HOME_PATH/datasets/vla-oft/libero/$DATASET/test.parquet\nexport MODEL_PATH=$HOME_PATH/models/Sylvest/OpenVLA-AC-PD-1traj-libero-long\nexport VJEPA_MODEL_PATH=$HOME_PATH/models/vjepa2/vitg-384.pt\n\n# Base output paths\nexport BASE_CKPT_PATH=ckpts\nexport BASE_TENSORBOARD_PATH=tensorboard\n\n# --- Embodied AI Specific Parameters ---\nexport ACTION_TOKEN_LEN=7        # 7 dimensions: xyz (3), quaternion (3), gripper (1)\nexport ACTION_CHUNKS_LEN=8       # OpenVLA-OFT uses 8-step action chunks\nexport NUM_ENVS=16               # actor_rollout_ref.embodied.env.num_envs\nexport MAX_EPISODE_STEPS=512     # actor_rollout_ref.embodied.env.max_steps\n\n# --- Data and Sampling Parameters ---\nexport VAL_BATCH_SIZE=496                      # Validation batch size\nexport MAX_PROMPT_LENGTH=256                   \nexport MAX_RESPONSE_LENGTH=128                 \n\n# --- Embodied Sampling Parameters ---\nexport FILTER_ACCURACY=True                    # Enable accuracy-based filtering\nexport ACCURACY_LOWER_BOUND=0.1                # Only keep prompts with success rate >= 0.1\nexport ACCURACY_UPPER_BOUND=0.9                # Only keep prompts with success rate <= 0.9\nexport FILTER_TRUNCATED=False                  # Filter truncated episodes (uses env.max_steps)\nexport OVERSAMPLE_FACTOR=1                     # Oversample factor for filtering\n\n# --- Training Hyperparameters ---\nexport TRAIN_BATCH_SIZE=64       # data.train_batch_size\nexport PPO_MINI_BATCH_SIZE=4    # actor_rollout_ref.actor.ppo_mini_batch_size\n                                # Note: actual ppo_mini_batch_size = PPO_MINI_BATCH_SIZE * ROLLOUT_N_SAMPLES\nexport ROLLOUT_N_SAMPLES=8       # REUSED: Number of samples per prompt\nexport PPO_EPOCHS=1              # actor_rollout_ref.actor.ppo_epochs\n\n# Algorithm parameters\nexport LEARNING_RATE=5e-6        \nexport WEIGHT_DECAY=0.0          # actor_rollout_ref.actor.optim.weight_decay\nexport CLIP_RATIO_HIGH=0.28      # actor_rollout_ref.actor.clip_ratio_high\nexport CLIP_RATIO_LOW=0.2        # actor_rollout_ref.actor.clip_ratio_low\nexport ENTROPY_COEFF=0.0         \nexport TEMPERATURE=1.6          \nexport GAMMA=1.0                 \nexport LAM=1.0                   \nexport GRAD_CLIP=1.0            \n\n# --- Image/Video Processing ---\nexport IMG_SIZE=384              # actor_rollout_ref.embodied.img_size\nexport ENABLE_FP16=True          # actor_rollout_ref.embodied.enable_fp16\nexport EMBEDDING_MODEL_OFFLOAD=False  # actor_rollout_ref.embodied.embedding_model_offload\nexport CENTER_CROP=True          # actor_rollout_ref.embodied.center_crop\nexport NUM_IMAGES_IN_INPUT=1     \nexport NUM_STEPS_WAIT=10           # Environment stabilization steps\n\n# --- Trainer 
Configuration ---\nexport SAVE_FREQ=5              \nexport TEST_FREQ=5              \nexport TOTAL_EPOCHS=1000         # trainer.total_epochs\nexport MAX_CKPT_KEEP=5           # trainer.max_actor_ckpt_to_keep\nexport VAL_BEFORE_TRAIN=True     # trainer.val_before_train\n\n# --- Multi-node distributed training ---\nexport N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}\nexport NNODES=${PET_NNODES:-1}\nexport NODE_RANK=${PET_NODE_RANK:-0}\nexport MASTER_ADDR=${MASTER_ADDR:-localhost}\nexport MASTER_PORT=${MASTER_PORT:-29500}\n\n# --- Environment Variables ---\nexport MUJOCO_GL=egl\nexport PYOPENGL_PLATFORM=egl\nexport GLOO_SOCKET_TIMEOUT=600\n\n# --- Output Paths and Experiment Naming ---\ntimestamp=$(date +%Y%m%d_%H%M%S)\nexport CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_${NNODES}nodes\nexport PROJECT_NAME=siirl_embodied_${DATASET}\nexport EXPERIMENT_NAME=openvla_oft_srpo_fsdp\nexport TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}/${timestamp}\nexport SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_${timestamp}\n\n# --- Define the Training Command ---\nTRAINING_CMD=(\n    python3 -m siirl.main_dag\n    # Data configuration\n    data.train_files=\\$TRAIN_DATA_PATH\n    data.val_files=\\$TEST_DATA_PATH\n    data.train_batch_size=\\$TRAIN_BATCH_SIZE\n    data.val_batch_size=\\$VAL_BATCH_SIZE\n    data.max_prompt_length=\\$MAX_PROMPT_LENGTH\n    data.max_response_length=\\$MAX_RESPONSE_LENGTH\n    data.dataset_type=embodied  \n\n    # Reward\n    reward_model.reward_manager=embodied\n    reward_model.reward_kwargs.action_token_len=7\n    reward_model.reward_kwargs.reward_coef=5.0\n\n    # Algorithm configuration\n    algorithm.workflow_type=embodied\n    algorithm.adv_estimator=grpo\n    algorithm.gamma=\\$GAMMA\n    algorithm.lam=\\$LAM\n    algorithm.norm_adv_by_std_in_grpo=True\n    \n    # Embodied sampling configuration (aligned with DAPO architecture)\n    algorithm.filter_groups.enable=True\n    algorithm.embodied_sampling.filter_accuracy=\\$FILTER_ACCURACY\n    algorithm.embodied_sampling.accuracy_lower_bound=\\$ACCURACY_LOWER_BOUND\n    algorithm.embodied_sampling.accuracy_upper_bound=\\$ACCURACY_UPPER_BOUND\n    algorithm.embodied_sampling.filter_truncated=\\$FILTER_TRUNCATED\n    algorithm.embodied_sampling.oversample_factor=\\$OVERSAMPLE_FACTOR\n    \n    # Model configuration\n    actor_rollout_ref.model.path=\\$MODEL_PATH\n    actor_rollout_ref.model.enable_gradient_checkpointing=True\n    actor_rollout_ref.model.model_type=embodied\n    actor_rollout_ref.model.trust_remote_code=True\n    # Actor configuration\n    actor_rollout_ref.actor.optim.lr=\\$LEARNING_RATE\n    actor_rollout_ref.actor.optim.weight_decay=\\$WEIGHT_DECAY\n    actor_rollout_ref.actor.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n    actor_rollout_ref.actor.ppo_epochs=\\$PPO_EPOCHS\n    actor_rollout_ref.actor.grad_clip=\\$GRAD_CLIP\n    actor_rollout_ref.actor.clip_ratio_c=10000.0\n    actor_rollout_ref.actor.clip_ratio_high=\\$CLIP_RATIO_HIGH\n    actor_rollout_ref.actor.clip_ratio_low=\\$CLIP_RATIO_LOW\n    actor_rollout_ref.actor.entropy_coeff=\\$ENTROPY_COEFF\n    actor_rollout_ref.actor.shuffle=True\n    \n    # Actor FSDP configuration\n    actor_rollout_ref.actor.fsdp_config.param_offload=False\n    actor_rollout_ref.actor.fsdp_config.grad_offload=False\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False\n    \n    # Rollout configuration\n    actor_rollout_ref.rollout.name=hf\n    actor_rollout_ref.rollout.n=\\$ROLLOUT_N_SAMPLES\n    
actor_rollout_ref.rollout.temperature=\\$TEMPERATURE\n    actor_rollout_ref.rollout.do_sample=True\n    actor_rollout_ref.rollout.prompt_length=\\$MAX_PROMPT_LENGTH  \n    actor_rollout_ref.rollout.response_length=512\n    \n    # Embodied AI specific configuration\n    actor_rollout_ref.embodied.embodied_type=\\$MODEL_TYPE\n    actor_rollout_ref.embodied.action_token_len=\\$ACTION_TOKEN_LEN\n    actor_rollout_ref.embodied.action_chunks_len=\\$ACTION_CHUNKS_LEN\n    actor_rollout_ref.embodied.video_embedding_model_path=\\$VJEPA_MODEL_PATH\n    actor_rollout_ref.embodied.embedding_img_size=\\$IMG_SIZE\n    actor_rollout_ref.embodied.embedding_enable_fp16=\\$ENABLE_FP16\n    actor_rollout_ref.embodied.embedding_model_offload=\\$EMBEDDING_MODEL_OFFLOAD\n    actor_rollout_ref.embodied.center_crop=\\$CENTER_CROP\n    actor_rollout_ref.embodied.num_images_in_input=\\$NUM_IMAGES_IN_INPUT\n    actor_rollout_ref.embodied.unnorm_key=\\$DATASET\n    \n    # Environment configuration\n    actor_rollout_ref.embodied.env.env_type=libero\n    actor_rollout_ref.embodied.env.env_name=\\$DATASET\n    actor_rollout_ref.embodied.env.num_envs=\\$NUM_ENVS\n    actor_rollout_ref.embodied.env.max_steps=\\$MAX_EPISODE_STEPS\n    actor_rollout_ref.embodied.env.num_steps_wait=\\$NUM_STEPS_WAIT\n    actor_rollout_ref.embodied.env.num_trials_per_task=50\n    actor_rollout_ref.embodied.env.model_family=openvla\n    \n    # Critic configuration (SRPO doesn't use critic)\n    critic.use_critic_model=False\n    \n    # Trainer configuration\n    trainer.total_epochs=\\$TOTAL_EPOCHS\n    trainer.save_freq=\\$SAVE_FREQ\n    trainer.test_freq=\\$TEST_FREQ\n    trainer.max_actor_ckpt_to_keep=\\$MAX_CKPT_KEEP\n    trainer.logger=['console','tensorboard']\n    trainer.project_name=\\$PROJECT_NAME\n    trainer.experiment_name=\\$EXPERIMENT_NAME\n    trainer.nnodes=\\$NNODES\n    trainer.n_gpus_per_node=\\$N_GPUS_PER_NODE\n    trainer.default_local_dir=\\$CKPT_PATH\n    trainer.resume_mode=auto\n    trainer.val_before_train=\\$VAL_BEFORE_TRAIN\n)\n\n# ===================================================================================\n# ===                          EXECUTION LOGIC                                    ===\n# ===================================================================================\n\n# --- Boilerplate Setup ---\nset -e\nset -o pipefail\nset -x\n\n# --- Infrastructure & Boilerplate Functions ---\nstart_ray_cluster() {\n    local RAY_HEAD_WAIT_TIMEOUT=600\n    export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200\n    export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120\n\n    local ray_start_common_opts=(\n        --num-gpus \"$N_GPUS_PER_NODE\"\n        --object-store-memory 100000000000\n        --memory 100000000000\n    )\n\n    if [ \"$NNODES\" -gt 1 ]; then\n        if [ \"$NODE_RANK\" = \"0\" ]; then\n            echo \"INFO: Starting Ray head node on $(hostname)...\"\n            export RAY_ADDRESS=\"$RAY_MASTER_ADDR:$RAY_MASTER_PORT\"\n            ray start --head --port=\"$RAY_MASTER_PORT\" --dashboard-port=\"$RAY_DASHBOARD_PORT\" \"${ray_start_common_opts[@]}\" --system-config='{\"gcs_server_request_timeout_seconds\": 60, \"gcs_rpc_server_reconnect_timeout_s\": 60}'\n            local start_time=$(date +%s)\n            while ! 
ray health-check --address \"$RAY_ADDRESS\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head node. Exiting.\" >&2; ray stop --force; exit 1; fi\n                echo \"Head node not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head node is healthy.\"\n        else\n            local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n            echo \"INFO: Worker node $(hostname) waiting for head at $head_node_address...\"\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$head_node_address\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head. Exiting.\" >&2; exit 1; fi\n                echo \"Head not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head is healthy. Worker starting...\"\n            ray start --address=\"$head_node_address\" \"${ray_start_common_opts[@]}\"\n        fi\n    else\n        echo \"INFO: Starting Ray in single-node mode...\"\n        ray start --head \"${ray_start_common_opts[@]}\"\n    fi\n}\n\n# --- Main Execution Function ---\nmain() {\n    local timestamp=$(date +\"%Y%m%d_%H%M%S\")\n    ray stop --force\n\n    export VLLM_USE_V1=1\n    export GLOO_SOCKET_TIMEOUT=600\n    export GLOO_TCP_TIMEOUT=600\n    export GLOO_LOG_LEVEL=DEBUG\n    export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}\n    export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}\n    export RAY_MASTER_ADDR=$MASTER_ADDR\n    \n    start_ray_cluster\n\n    if [ \"$NNODES\" -gt 1 ] && [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"Waiting for all $NNODES nodes to join...\"\n        local TIMEOUT=600; local start_time=$(date +%s)\n        while true; do\n            if [ \"$(( $(date +%s) - start_time ))\" -ge \"$TIMEOUT\" ]; then echo \"Error: Timeout waiting for nodes.\" >&2; exit 1; fi\n            local ready_nodes=$(ray list nodes --format=json | python3 -c \"import sys, json; print(len(json.load(sys.stdin)))\")\n            if [ \"$ready_nodes\" -ge \"$NNODES\" ]; then break; fi\n            echo \"Waiting... ($ready_nodes / $NNODES nodes ready)\"\n            sleep 5\n        done\n        echo \"All $NNODES nodes have joined.\"\n    fi\n\n    if [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"INFO [RANK 0]: Starting main training command.\"\n        eval \"${TRAINING_CMD[@]}\" \"$@\"\n        echo \"INFO [RANK 0]: Training finished.\"\n        sleep 30; ray stop --force >/dev/null 2>&1\n    elif [ \"$NNODES\" -gt 1 ]; then\n        local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n        echo \"INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address.\"\n        while ray health-check --address \"$head_node_address\" &>/dev/null; do sleep 15; done\n        echo \"INFO [RANK $NODE_RANK]: Head node down. Exiting.\"\n    fi\n\n    echo \"INFO: Script finished on rank $NODE_RANK.\"\n}\n\n# --- Script Entrypoint ---\nmain \"$@\"\n"
  },
  {
    "path": "examples/embodied_srpo_trainer/run_openvla_oft_libero_object.sh",
    "content": "#!/usr/bin/env bash\n# ===================================================================================\n# ===    Embodied AI SRPO Training with OpenVLA-OFT on LIBERO-OBJECT               ===\n# ===================================================================================\n# \n\nset -e\n\n# --- Environment Setup (Critical for siiRL) ---\nexport SIIRL_DIR=\"${SIIRL_DIR:your_siirl_path}\"\nexport PYTHONPATH=\"$SIIRL_DIR:/root/LIBERO/:your_vjepa2_path:$PYTHONPATH\"\n\n# --- Experiment and Model Definition ---\nexport DATASET=libero_object\nexport ALG=srpo\nexport MODEL_NAME=openvla-oft-7b\nexport MODEL_TYPE=openvla-oft\n\n# --- Path Definitions (USER PROVIDED) ---\nexport HOME_PATH=${HOME_PATH:your_home_path}\nexport TRAIN_DATA_PATH=$HOME_PATH/datasets/vla-oft/libero/$DATASET/train.parquet\nexport TEST_DATA_PATH=$HOME_PATH/datasets/vla-oft/libero/$DATASET/test.parquet\nexport MODEL_PATH=$HOME_PATH/models/Sylvest/OpenVLA-AC-PD-1traj-libero-object\nexport VJEPA_MODEL_PATH=$HOME_PATH/models/vjepa2/vitg-384.pt\n\n# Base output paths\nexport BASE_CKPT_PATH=ckpts\nexport BASE_TENSORBOARD_PATH=tensorboard\n\n# --- Embodied AI Specific Parameters ---\nexport ACTION_TOKEN_LEN=7        # 7 dimensions: xyz (3), quaternion (3), gripper (1)\nexport ACTION_CHUNKS_LEN=8       # OpenVLA-OFT uses 8-step action chunks\nexport NUM_ENVS=16               # actor_rollout_ref.embodied.env.num_envs\nexport MAX_EPISODE_STEPS=512     # actor_rollout_ref.embodied.env.max_steps\n\n# --- Data and Sampling Parameters ---\nexport VAL_BATCH_SIZE=496                      # Validation batch size\nexport MAX_PROMPT_LENGTH=256                   \nexport MAX_RESPONSE_LENGTH=128                 \n\n# --- Embodied Sampling Parameters ---\nexport FILTER_ACCURACY=True                    # Enable accuracy-based filtering\nexport ACCURACY_LOWER_BOUND=0.1                # Only keep prompts with success rate >= 0.1\nexport ACCURACY_UPPER_BOUND=0.9                # Only keep prompts with success rate <= 0.9\nexport FILTER_TRUNCATED=False                  # Filter truncated episodes (uses env.max_steps)\nexport OVERSAMPLE_FACTOR=1                     # Oversample factor for filtering\n\n# --- Training Hyperparameters ---\nexport TRAIN_BATCH_SIZE=64       # data.train_batch_size\nexport PPO_MINI_BATCH_SIZE=4    # actor_rollout_ref.actor.ppo_mini_batch_size\n                                # Note: actual ppo_mini_batch_size = PPO_MINI_BATCH_SIZE * ROLLOUT_N_SAMPLES\nexport ROLLOUT_N_SAMPLES=8       # REUSED: Number of samples per prompt\nexport PPO_EPOCHS=1              # actor_rollout_ref.actor.ppo_epochs\n\n# Algorithm parameters\nexport LEARNING_RATE=5e-6        \nexport WEIGHT_DECAY=0.0          # actor_rollout_ref.actor.optim.weight_decay\nexport CLIP_RATIO_HIGH=0.28      # actor_rollout_ref.actor.clip_ratio_high\nexport CLIP_RATIO_LOW=0.2        # actor_rollout_ref.actor.clip_ratio_low\nexport ENTROPY_COEFF=0.0         \nexport TEMPERATURE=1.6          \nexport GAMMA=1.0                 \nexport LAM=1.0                   \nexport GRAD_CLIP=1.0            \n\n# --- Image/Video Processing ---\nexport IMG_SIZE=384              # actor_rollout_ref.embodied.img_size\nexport ENABLE_FP16=True          # actor_rollout_ref.embodied.enable_fp16\nexport EMBEDDING_MODEL_OFFLOAD=False  # actor_rollout_ref.embodied.embedding_model_offload\nexport CENTER_CROP=True          # actor_rollout_ref.embodied.center_crop\nexport NUM_IMAGES_IN_INPUT=1     \nexport NUM_STEPS_WAIT=10           # Environment stabilization steps\n\n# --- 
Trainer Configuration ---\nexport SAVE_FREQ=5              \nexport TEST_FREQ=5              \nexport TOTAL_EPOCHS=1000         # trainer.total_epochs\nexport MAX_CKPT_KEEP=5           # trainer.max_actor_ckpt_to_keep\nexport VAL_BEFORE_TRAIN=True     # trainer.val_before_train\n\n# --- Multi-node distributed training ---\nexport N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}\nexport NNODES=${PET_NNODES:-1}\nexport NODE_RANK=${PET_NODE_RANK:-0}\nexport MASTER_ADDR=${MASTER_ADDR:-localhost}\nexport MASTER_PORT=${MASTER_PORT:-29500}\n\n# --- Environment Variables ---\nexport MUJOCO_GL=egl\nexport PYOPENGL_PLATFORM=egl\nexport GLOO_SOCKET_TIMEOUT=600\n\n# --- Output Paths and Experiment Naming ---\ntimestamp=$(date +%Y%m%d_%H%M%S)\nexport CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_${NNODES}nodes\nexport PROJECT_NAME=siirl_embodied_${DATASET}\nexport EXPERIMENT_NAME=openvla_oft_srpo_fsdp\nexport TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}/${timestamp}\nexport SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_${timestamp}\n\n# --- Define the Training Command ---\nTRAINING_CMD=(\n    python3 -m siirl.main_dag \n    # Data configuration\n    data.train_files=\\$TRAIN_DATA_PATH\n    data.val_files=\\$TEST_DATA_PATH\n    data.train_batch_size=\\$TRAIN_BATCH_SIZE\n    data.val_batch_size=\\$VAL_BATCH_SIZE\n    data.max_prompt_length=\\$MAX_PROMPT_LENGTH\n    data.max_response_length=\\$MAX_RESPONSE_LENGTH\n    data.dataset_type=embodied  \n\n    # Reward\n    reward_model.reward_manager=embodied\n    reward_model.reward_kwargs.action_token_len=7\n    reward_model.reward_kwargs.reward_coef=5.0\n\n    # Algorithm configuration\n    algorithm.workflow_type=embodied\n    algorithm.adv_estimator=grpo\n    algorithm.gamma=\\$GAMMA\n    algorithm.lam=\\$LAM\n    algorithm.norm_adv_by_std_in_grpo=True\n    \n    # Embodied sampling configuration (aligned with DAPO architecture)\n    algorithm.filter_groups.enable=True\n    algorithm.embodied_sampling.filter_accuracy=\\$FILTER_ACCURACY\n    algorithm.embodied_sampling.accuracy_lower_bound=\\$ACCURACY_LOWER_BOUND\n    algorithm.embodied_sampling.accuracy_upper_bound=\\$ACCURACY_UPPER_BOUND\n    algorithm.embodied_sampling.filter_truncated=\\$FILTER_TRUNCATED\n    algorithm.embodied_sampling.oversample_factor=\\$OVERSAMPLE_FACTOR\n    \n    # Model configuration\n    actor_rollout_ref.model.path=\\$MODEL_PATH\n    actor_rollout_ref.model.enable_gradient_checkpointing=True\n    actor_rollout_ref.model.model_type=embodied\n    actor_rollout_ref.model.trust_remote_code=True\n    # Actor configuration\n    actor_rollout_ref.actor.optim.lr=\\$LEARNING_RATE\n    actor_rollout_ref.actor.optim.weight_decay=\\$WEIGHT_DECAY\n    actor_rollout_ref.actor.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n    actor_rollout_ref.actor.ppo_epochs=\\$PPO_EPOCHS\n    actor_rollout_ref.actor.grad_clip=\\$GRAD_CLIP\n    actor_rollout_ref.actor.clip_ratio_c=10000.0\n    actor_rollout_ref.actor.clip_ratio_high=\\$CLIP_RATIO_HIGH\n    actor_rollout_ref.actor.clip_ratio_low=\\$CLIP_RATIO_LOW\n    actor_rollout_ref.actor.entropy_coeff=\\$ENTROPY_COEFF\n    actor_rollout_ref.actor.shuffle=True\n    \n    # Actor FSDP configuration\n    actor_rollout_ref.actor.fsdp_config.param_offload=False\n    actor_rollout_ref.actor.fsdp_config.grad_offload=False\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False\n    \n    # Rollout configuration\n    actor_rollout_ref.rollout.name=hf\n    actor_rollout_ref.rollout.n=\\$ROLLOUT_N_SAMPLES\n    
actor_rollout_ref.rollout.temperature=\\$TEMPERATURE\n    actor_rollout_ref.rollout.do_sample=True\n    actor_rollout_ref.rollout.prompt_length=\\$MAX_PROMPT_LENGTH  \n    actor_rollout_ref.rollout.response_length=512\n    \n    # Embodied AI specific configuration\n    actor_rollout_ref.embodied.embodied_type=\\$MODEL_TYPE\n    actor_rollout_ref.embodied.action_token_len=\\$ACTION_TOKEN_LEN\n    actor_rollout_ref.embodied.action_chunks_len=\\$ACTION_CHUNKS_LEN\n    actor_rollout_ref.embodied.video_embedding_model_path=\\$VJEPA_MODEL_PATH\n    actor_rollout_ref.embodied.embedding_img_size=\\$IMG_SIZE\n    actor_rollout_ref.embodied.embedding_enable_fp16=\\$ENABLE_FP16\n    actor_rollout_ref.embodied.embedding_model_offload=\\$EMBEDDING_MODEL_OFFLOAD\n    actor_rollout_ref.embodied.center_crop=\\$CENTER_CROP\n    actor_rollout_ref.embodied.num_images_in_input=\\$NUM_IMAGES_IN_INPUT\n    actor_rollout_ref.embodied.unnorm_key=\\$DATASET\n    \n    # Environment configuration\n    actor_rollout_ref.embodied.env.env_type=libero\n    actor_rollout_ref.embodied.env.env_name=\\$DATASET\n    actor_rollout_ref.embodied.env.num_envs=\\$NUM_ENVS\n    actor_rollout_ref.embodied.env.max_steps=\\$MAX_EPISODE_STEPS\n    actor_rollout_ref.embodied.env.num_steps_wait=\\$NUM_STEPS_WAIT\n    actor_rollout_ref.embodied.env.num_trials_per_task=50\n    actor_rollout_ref.embodied.env.model_family=openvla\n    \n    # Critic configuration (SRPO doesn't use critic)\n    critic.use_critic_model=False\n    \n    # Trainer configuration\n    trainer.total_epochs=\\$TOTAL_EPOCHS\n    trainer.save_freq=\\$SAVE_FREQ\n    trainer.test_freq=\\$TEST_FREQ\n    trainer.max_actor_ckpt_to_keep=\\$MAX_CKPT_KEEP\n    trainer.logger=['console','tensorboard']\n    trainer.project_name=\\$PROJECT_NAME\n    trainer.experiment_name=\\$EXPERIMENT_NAME\n    trainer.nnodes=\\$NNODES\n    trainer.n_gpus_per_node=\\$N_GPUS_PER_NODE\n    trainer.default_local_dir=\\$CKPT_PATH\n    trainer.resume_mode=auto\n    trainer.val_before_train=\\$VAL_BEFORE_TRAIN\n)\n\n# ===================================================================================\n# ===                          EXECUTION LOGIC                                    ===\n# ===================================================================================\n\n# --- Boilerplate Setup ---\nset -e\nset -o pipefail\nset -x\n\n# --- Infrastructure & Boilerplate Functions ---\nstart_ray_cluster() {\n    local RAY_HEAD_WAIT_TIMEOUT=600\n    export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200\n    export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120\n\n    local ray_start_common_opts=(\n        --num-gpus \"$N_GPUS_PER_NODE\"\n        --object-store-memory 100000000000\n        --memory 100000000000\n    )\n\n    if [ \"$NNODES\" -gt 1 ]; then\n        if [ \"$NODE_RANK\" = \"0\" ]; then\n            echo \"INFO: Starting Ray head node on $(hostname)...\"\n            export RAY_ADDRESS=\"$RAY_MASTER_ADDR:$RAY_MASTER_PORT\"\n            ray start --head --port=\"$RAY_MASTER_PORT\" --dashboard-port=\"$RAY_DASHBOARD_PORT\" \"${ray_start_common_opts[@]}\" --system-config='{\"gcs_server_request_timeout_seconds\": 60, \"gcs_rpc_server_reconnect_timeout_s\": 60}'\n            local start_time=$(date +%s)\n            while ! 
ray health-check --address \"$RAY_ADDRESS\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head node. Exiting.\" >&2; ray stop --force; exit 1; fi\n                echo \"Head node not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head node is healthy.\"\n        else\n            local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n            echo \"INFO: Worker node $(hostname) waiting for head at $head_node_address...\"\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$head_node_address\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head. Exiting.\" >&2; exit 1; fi\n                echo \"Head not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head is healthy. Worker starting...\"\n            ray start --address=\"$head_node_address\" \"${ray_start_common_opts[@]}\"\n        fi\n    else\n        echo \"INFO: Starting Ray in single-node mode...\"\n        ray start --head \"${ray_start_common_opts[@]}\"\n    fi\n}\n\n# --- Main Execution Function ---\nmain() {\n    local timestamp=$(date +\"%Y%m%d_%H%M%S\")\n    ray stop --force\n\n    export VLLM_USE_V1=1\n    export GLOO_SOCKET_TIMEOUT=600\n    export GLOO_TCP_TIMEOUT=600\n    export GLOO_LOG_LEVEL=DEBUG\n    export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}\n    export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}\n    export RAY_MASTER_ADDR=$MASTER_ADDR\n    \n    start_ray_cluster\n\n    if [ \"$NNODES\" -gt 1 ] && [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"Waiting for all $NNODES nodes to join...\"\n        local TIMEOUT=600; local start_time=$(date +%s)\n        while true; do\n            if [ \"$(( $(date +%s) - start_time ))\" -ge \"$TIMEOUT\" ]; then echo \"Error: Timeout waiting for nodes.\" >&2; exit 1; fi\n            local ready_nodes=$(ray list nodes --format=json | python3 -c \"import sys, json; print(len(json.load(sys.stdin)))\")\n            if [ \"$ready_nodes\" -ge \"$NNODES\" ]; then break; fi\n            echo \"Waiting... ($ready_nodes / $NNODES nodes ready)\"\n            sleep 5\n        done\n        echo \"All $NNODES nodes have joined.\"\n    fi\n\n    if [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"INFO [RANK 0]: Starting main training command.\"\n        eval \"${TRAINING_CMD[@]}\" \"$@\"\n        echo \"INFO [RANK 0]: Training finished.\"\n        sleep 30; ray stop --force >/dev/null 2>&1\n    elif [ \"$NNODES\" -gt 1 ]; then\n        local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n        echo \"INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address.\"\n        while ray health-check --address \"$head_node_address\" &>/dev/null; do sleep 15; done\n        echo \"INFO [RANK $NODE_RANK]: Head node down. Exiting.\"\n    fi\n\n    echo \"INFO: Script finished on rank $NODE_RANK.\"\n}\n\n# --- Script Entrypoint ---\nmain \"$@\"\n"
  },
  {
    "path": "examples/embodied_srpo_trainer/run_openvla_oft_libero_spatial.sh",
    "content": "#!/usr/bin/env bash\n# ===================================================================================\n# ===    Embodied AI SRPO Training with OpenVLA-OFT on LIBERO-SPATIAL             ===\n# ===================================================================================\n# \n\nset -e\n\n# --- Environment Setup (Critical for siiRL) ---\nexport SIIRL_DIR=\"${SIIRL_DIR:your_siirl_path}\"\nexport PYTHONPATH=\"$SIIRL_DIR:/root/LIBERO/:your_vjepa2_path:$PYTHONPATH\"\n\n# --- Experiment and Model Definition ---\nexport DATASET=libero_spatial\nexport ALG=srpo\nexport MODEL_NAME=openvla-oft-7b\nexport MODEL_TYPE=openvla-oft\n\n# --- Path Definitions (USER PROVIDED) ---\nexport HOME_PATH=${HOME_PATH:your_home_path}\nexport TRAIN_DATA_PATH=$HOME_PATH/datasets/vla-oft/libero/$DATASET/train.parquet\nexport TEST_DATA_PATH=$HOME_PATH/datasets/vla-oft/libero/$DATASET/test.parquet\nexport MODEL_PATH=$HOME_PATH/models/Sylvest/OpenVLA-AC-PD-1traj-libero-spatial\nexport VJEPA_MODEL_PATH=$HOME_PATH/models/vjepa2/vitg-384.pt\n\n# Base output paths\nexport BASE_CKPT_PATH=ckpts\nexport BASE_TENSORBOARD_PATH=tensorboard\n\n# --- Embodied AI Specific Parameters ---\nexport ACTION_TOKEN_LEN=7        # 7 dimensions: xyz (3), quaternion (3), gripper (1)\nexport ACTION_CHUNKS_LEN=8       # OpenVLA-OFT uses 8-step action chunks\nexport NUM_ENVS=16               # actor_rollout_ref.embodied.env.num_envs\nexport MAX_EPISODE_STEPS=512     # actor_rollout_ref.embodied.env.max_steps\n\n# --- Data and Sampling Parameters ---\nexport VAL_BATCH_SIZE=496                      # Validation batch size\nexport MAX_PROMPT_LENGTH=256                   \nexport MAX_RESPONSE_LENGTH=128                 \n\n# --- Embodied Sampling Parameters ---\nexport FILTER_ACCURACY=True                    # Enable accuracy-based filtering\nexport ACCURACY_LOWER_BOUND=0.1                # Only keep prompts with success rate >= 0.1\nexport ACCURACY_UPPER_BOUND=0.9                # Only keep prompts with success rate <= 0.9\nexport FILTER_TRUNCATED=False                  # Filter truncated episodes (uses env.max_steps)\nexport OVERSAMPLE_FACTOR=1                     # Oversample factor for filtering\n\n# --- Training Hyperparameters ---\nexport TRAIN_BATCH_SIZE=64       # data.train_batch_size\nexport PPO_MINI_BATCH_SIZE=4    # actor_rollout_ref.actor.ppo_mini_batch_size\n                                # Note: actual ppo_mini_batch_size = PPO_MINI_BATCH_SIZE * ROLLOUT_N_SAMPLES\nexport ROLLOUT_N_SAMPLES=8       # REUSED: Number of samples per prompt\nexport PPO_EPOCHS=1              # actor_rollout_ref.actor.ppo_epochs\n\n# Algorithm parameters\nexport LEARNING_RATE=5e-6        \nexport WEIGHT_DECAY=0.0          # actor_rollout_ref.actor.optim.weight_decay\nexport CLIP_RATIO_HIGH=0.28      # actor_rollout_ref.actor.clip_ratio_high\nexport CLIP_RATIO_LOW=0.2        # actor_rollout_ref.actor.clip_ratio_low\nexport ENTROPY_COEFF=0.0         \nexport TEMPERATURE=1.6          \nexport GAMMA=1.0                 \nexport LAM=1.0                   \nexport GRAD_CLIP=1.0            \n\n# --- Image/Video Processing ---\nexport IMG_SIZE=384              # actor_rollout_ref.embodied.img_size\nexport ENABLE_FP16=True          # actor_rollout_ref.embodied.enable_fp16\nexport EMBEDDING_MODEL_OFFLOAD=False  # actor_rollout_ref.embodied.embedding_model_offload\nexport CENTER_CROP=True          # ctor_rollout_ref.embodied.center_crop\nexport NUM_IMAGES_IN_INPUT=1     \nexport NUM_STEPS_WAIT=10           # Environment stabilization steps\n\n# --- 
Trainer Configuration ---\nexport SAVE_FREQ=5              \nexport TEST_FREQ=5              \nexport TOTAL_EPOCHS=1000         # trainer.total_epochs\nexport MAX_CKPT_KEEP=5           # trainer.max_actor_ckpt_to_keep\nexport VAL_BEFORE_TRAIN=True     # trainer.val_before_train\n\n# --- Multi-node distributed training ---\nexport N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}\nexport NNODES=${PET_NNODES:-1}\nexport NODE_RANK=${PET_NODE_RANK:-0}\nexport MASTER_ADDR=${MASTER_ADDR:-localhost}\nexport MASTER_PORT=${MASTER_PORT:-29500}\n\n# --- Environment Variables ---\nexport MUJOCO_GL=egl\nexport PYOPENGL_PLATFORM=egl\nexport GLOO_SOCKET_TIMEOUT=600\n\n# --- Output Paths and Experiment Naming ---\ntimestamp=$(date +%Y%m%d_%H%M%S)\nexport CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_${NNODES}nodes\nexport PROJECT_NAME=siirl_embodied_${DATASET}\nexport EXPERIMENT_NAME=openvla_oft_srpo_fsdp\nexport TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}/${timestamp}\nexport SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_${timestamp}\n\n# --- Define the Training Command ---\nTRAINING_CMD=(\n    python3 -m siirl.main_dag\n    # Data configuration\n    data.train_files=\\$TRAIN_DATA_PATH\n    data.val_files=\\$TEST_DATA_PATH\n    data.train_batch_size=\\$TRAIN_BATCH_SIZE\n    data.val_batch_size=\\$VAL_BATCH_SIZE\n    data.max_prompt_length=\\$MAX_PROMPT_LENGTH\n    data.max_response_length=\\$MAX_RESPONSE_LENGTH\n    data.dataset_type=embodied  \n\n    # Reward\n    reward_model.reward_manager=embodied\n    reward_model.reward_kwargs.action_token_len=7\n    reward_model.reward_kwargs.reward_coef=5.0\n\n    # Algorithm configuration\n    algorithm.workflow_type=embodied\n    algorithm.adv_estimator=grpo\n    algorithm.gamma=\\$GAMMA\n    algorithm.lam=\\$LAM\n    algorithm.norm_adv_by_std_in_grpo=True\n    \n    # Embodied sampling configuration (aligned with DAPO architecture)\n    algorithm.filter_groups.enable=True\n    algorithm.embodied_sampling.filter_accuracy=\\$FILTER_ACCURACY\n    algorithm.embodied_sampling.accuracy_lower_bound=\\$ACCURACY_LOWER_BOUND\n    algorithm.embodied_sampling.accuracy_upper_bound=\\$ACCURACY_UPPER_BOUND\n    algorithm.embodied_sampling.filter_truncated=\\$FILTER_TRUNCATED\n    algorithm.embodied_sampling.oversample_factor=\\$OVERSAMPLE_FACTOR\n    \n    # Model configuration\n    actor_rollout_ref.model.path=\\$MODEL_PATH\n    actor_rollout_ref.model.enable_gradient_checkpointing=True\n    actor_rollout_ref.model.model_type=embodied\n    actor_rollout_ref.model.trust_remote_code=True\n    # Actor configuration\n    actor_rollout_ref.actor.optim.lr=\\$LEARNING_RATE\n    actor_rollout_ref.actor.optim.weight_decay=\\$WEIGHT_DECAY\n    actor_rollout_ref.actor.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n    actor_rollout_ref.actor.ppo_epochs=\\$PPO_EPOCHS\n    actor_rollout_ref.actor.grad_clip=\\$GRAD_CLIP\n    actor_rollout_ref.actor.clip_ratio_c=10000.0\n    actor_rollout_ref.actor.clip_ratio_high=\\$CLIP_RATIO_HIGH\n    actor_rollout_ref.actor.clip_ratio_low=\\$CLIP_RATIO_LOW\n    actor_rollout_ref.actor.entropy_coeff=\\$ENTROPY_COEFF\n    actor_rollout_ref.actor.shuffle=True\n    \n    # Actor FSDP configuration\n    actor_rollout_ref.actor.fsdp_config.param_offload=False\n    actor_rollout_ref.actor.fsdp_config.grad_offload=False\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False\n    \n    # Rollout configuration\n    actor_rollout_ref.rollout.name=hf\n    actor_rollout_ref.rollout.n=\\$ROLLOUT_N_SAMPLES\n    
actor_rollout_ref.rollout.temperature=\\$TEMPERATURE\n    actor_rollout_ref.rollout.do_sample=True\n    actor_rollout_ref.rollout.prompt_length=\\$MAX_PROMPT_LENGTH  \n    actor_rollout_ref.rollout.response_length=512\n    \n    # Embodied AI specific configuration\n    actor_rollout_ref.embodied.embodied_type=\\$MODEL_TYPE\n    actor_rollout_ref.embodied.action_token_len=\\$ACTION_TOKEN_LEN\n    actor_rollout_ref.embodied.action_chunks_len=\\$ACTION_CHUNKS_LEN\n    actor_rollout_ref.embodied.video_embedding_model_path=\\$VJEPA_MODEL_PATH\n    actor_rollout_ref.embodied.embedding_img_size=\\$IMG_SIZE\n    actor_rollout_ref.embodied.embedding_enable_fp16=\\$ENABLE_FP16\n    actor_rollout_ref.embodied.embedding_model_offload=\\$EMBEDDING_MODEL_OFFLOAD\n    actor_rollout_ref.embodied.center_crop=\\$CENTER_CROP\n    actor_rollout_ref.embodied.num_images_in_input=\\$NUM_IMAGES_IN_INPUT\n    actor_rollout_ref.embodied.unnorm_key=\\$DATASET\n    \n    # Environment configuration\n    actor_rollout_ref.embodied.env.env_type=libero\n    actor_rollout_ref.embodied.env.env_name=\\$DATASET\n    actor_rollout_ref.embodied.env.num_envs=\\$NUM_ENVS\n    actor_rollout_ref.embodied.env.max_steps=\\$MAX_EPISODE_STEPS\n    actor_rollout_ref.embodied.env.num_steps_wait=\\$NUM_STEPS_WAIT\n    actor_rollout_ref.embodied.env.num_trials_per_task=50\n    actor_rollout_ref.embodied.env.model_family=openvla\n    \n    # Critic configuration (SRPO doesn't use critic)\n    critic.use_critic_model=False\n    \n    # Trainer configuration\n    trainer.total_epochs=\\$TOTAL_EPOCHS\n    trainer.save_freq=\\$SAVE_FREQ\n    trainer.test_freq=\\$TEST_FREQ\n    trainer.max_actor_ckpt_to_keep=\\$MAX_CKPT_KEEP\n    trainer.logger=['console','tensorboard']\n    trainer.project_name=\\$PROJECT_NAME\n    trainer.experiment_name=\\$EXPERIMENT_NAME\n    trainer.nnodes=\\$NNODES\n    trainer.n_gpus_per_node=\\$N_GPUS_PER_NODE\n    trainer.default_local_dir=\\$CKPT_PATH\n    trainer.resume_mode=auto\n    trainer.val_before_train=\\$VAL_BEFORE_TRAIN\n)\n\n# ===================================================================================\n# ===                          EXECUTION LOGIC                                    ===\n# ===================================================================================\n\n# --- Boilerplate Setup ---\nset -e\nset -o pipefail\nset -x\n\n# --- Infrastructure & Boilerplate Functions ---\nstart_ray_cluster() {\n    local RAY_HEAD_WAIT_TIMEOUT=600\n    export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200\n    export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120\n\n    local ray_start_common_opts=(\n        --num-gpus \"$N_GPUS_PER_NODE\"\n        --object-store-memory 100000000000\n        --memory 100000000000\n    )\n\n    if [ \"$NNODES\" -gt 1 ]; then\n        if [ \"$NODE_RANK\" = \"0\" ]; then\n            echo \"INFO: Starting Ray head node on $(hostname)...\"\n            export RAY_ADDRESS=\"$RAY_MASTER_ADDR:$RAY_MASTER_PORT\"\n            ray start --head --port=\"$RAY_MASTER_PORT\" --dashboard-port=\"$RAY_DASHBOARD_PORT\" \"${ray_start_common_opts[@]}\" --system-config='{\"gcs_server_request_timeout_seconds\": 60, \"gcs_rpc_server_reconnect_timeout_s\": 60}'\n            local start_time=$(date +%s)\n            while ! 
ray health-check --address \"$RAY_ADDRESS\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head node. Exiting.\" >&2; ray stop --force; exit 1; fi\n                echo \"Head node not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head node is healthy.\"\n        else\n            local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n            echo \"INFO: Worker node $(hostname) waiting for head at $head_node_address...\"\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$head_node_address\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head. Exiting.\" >&2; exit 1; fi\n                echo \"Head not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head is healthy. Worker starting...\"\n            ray start --address=\"$head_node_address\" \"${ray_start_common_opts[@]}\"\n        fi\n    else\n        echo \"INFO: Starting Ray in single-node mode...\"\n        ray start --head \"${ray_start_common_opts[@]}\"\n    fi\n}\n\n# --- Main Execution Function ---\nmain() {\n    local timestamp=$(date +\"%Y%m%d_%H%M%S\")\n    ray stop --force\n\n    export VLLM_USE_V1=1\n    export GLOO_SOCKET_TIMEOUT=600\n    export GLOO_TCP_TIMEOUT=600\n    export GLOO_LOG_LEVEL=DEBUG\n    export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}\n    export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}\n    export RAY_MASTER_ADDR=$MASTER_ADDR\n    \n    start_ray_cluster\n\n    if [ \"$NNODES\" -gt 1 ] && [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"Waiting for all $NNODES nodes to join...\"\n        local TIMEOUT=600; local start_time=$(date +%s)\n        while true; do\n            if [ \"$(( $(date +%s) - start_time ))\" -ge \"$TIMEOUT\" ]; then echo \"Error: Timeout waiting for nodes.\" >&2; exit 1; fi\n            local ready_nodes=$(ray list nodes --format=json | python3 -c \"import sys, json; print(len(json.load(sys.stdin)))\")\n            if [ \"$ready_nodes\" -ge \"$NNODES\" ]; then break; fi\n            echo \"Waiting... ($ready_nodes / $NNODES nodes ready)\"\n            sleep 5\n        done\n        echo \"All $NNODES nodes have joined.\"\n    fi\n\n    if [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"INFO [RANK 0]: Starting main training command.\"\n        eval \"${TRAINING_CMD[@]}\" \"$@\"\n        echo \"INFO [RANK 0]: Training finished.\"\n        sleep 30; ray stop --force >/dev/null 2>&1\n    elif [ \"$NNODES\" -gt 1 ]; then\n        local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n        echo \"INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address.\"\n        while ray health-check --address \"$head_node_address\" &>/dev/null; do sleep 15; done\n        echo \"INFO [RANK $NODE_RANK]: Head node down. Exiting.\"\n    fi\n\n    echo \"INFO: Script finished on rank $NODE_RANK.\"\n}\n\n# --- Script Entrypoint ---\nmain \"$@\"\n"
  },
  {
    "path": "examples/experimental/marft/config/code_env.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom siirl.utils.reward_score.prime_code import compute_score\nfrom typing import Any, Dict, Optional, Tuple\nimport asyncio\nclass CodeEnv():\n    def __init__(self):\n        pass\n    def reset(self) -> Any:\n        pass\n    async def step(self, actions, ground_truth):\n        actor_action = actions[-1]\n        loop = asyncio.get_event_loop()\n        score, _ = await loop.run_in_executor(\n            None, \n            compute_score, \n            actor_action, ground_truth \n        )\n        score = float(score)\n        should_stop = False\n        if score == 1.0:\n            next_obs = [act + \". This answer is right.\" for act in actions]\n            should_stop = True\n        else:\n            next_obs = [act + \". This answer is wrong.\" for act in actions]\n        return next_obs, score, should_stop\n            \n    "
  },
  {
    "path": "examples/experimental/marft/config/math_env.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport asyncio\nfrom siirl.utils.reward_score.math import compute_score\nfrom typing import Any, Dict, Optional, Tuple\nclass MathEnv():\n    def __init__(self):\n        pass\n    def reset(self) -> Any:\n        pass\n    async def step(self, actions, ground_truth):\n        actor_action = actions[-1]\n        loop = asyncio.get_event_loop()\n        score = await loop.run_in_executor(\n            None, \n            compute_score, \n            actor_action, ground_truth \n        )\n        should_stop = False\n        if score == 1.0:\n            next_obs = [act + \" This answer is right.\" for act in actions]\n            should_stop = True\n        else:\n            next_obs = [act + \" This answer is wrong.\" for act in actions]\n        return next_obs, score, should_stop\n            \n    "
  },
  {
    "path": "examples/experimental/marft/config/process.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom string import Template\ndef pre_process(tokenizer, prompt_id, obs, **kwargs):\n    pre_chat_template = Template(kwargs.get(\"pre_chat_template\", \"\"))\n    prompt = tokenizer.decode(prompt_id)\n    prompt = pre_chat_template.substitute(prompt = prompt)\n    message = [\n        {\"role\": \"system\", \"content\": \"\"},\n        {\"role\": \"user\", \"content\": prompt}\n    ]\n    return tokenizer.apply_chat_template(message, tokenize=True, add_generation_prompt=True, add_special_tokens=False)\n\ndef post_process(tokenizer, prompt_id, response_id, **kwargs):\n    post_chat_template = kwargs.get(\"post_chat_template\", None)\n    post_chat_template_id = tokenizer.encode(post_chat_template)   \n    return prompt_id + post_chat_template_id + response_id\n"
  },
  {
    "path": "examples/experimental/marft/config/workflow_marft.yaml",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\ndag_id: \"marft_ppo_training_pipeline\"\ndescription: \"This is MARFT DAG workflow configured via YAML.\"\n\nactor_1_config: &actor1_config\n  rollout.log_prob_micro_batch_size_per_gpu: 16\n  rollout.tensor_model_parallel_size: 4\n  rollout.gpu_memory_utilization: 0.3\n  rollout.n: 1\n  \n\n\nactor_2_config: &actor2_config\n  rollout.log_prob_micro_batch_size_per_gpu: 16\n  rollout.tensor_model_parallel_size: 4\n  rollout.gpu_memory_utilization: 0.3\n  rollout.n: 1\n  \nnodes:\n  - node_id: \"rollout_reasoner\"\n    node_type: \"MODEL_INFERENCE\"\n    node_role: \"ROLLOUT\"\n    config: *actor1_config\n    agent_group: 0\n    dependencies: []\n    agent_options:\n      obs_with_env: true\n      process_path: examples/experimental/marft/config/process.py\n      pre_process_kwargs: \n        pre_chat_template: \"<|im_start|>system: Two LLM agents (Reasoner -> Actor) collaborate step-by-step to solve math problems. You are the **Reasoner**: Analyze the original problem, historical actions, and reflection data (if provided) to determine the critical next step. Guide the Actor by providing concise reasoning for the optimal operation.<|im_end|>\\n <|im_start|> problem: ${prompt} <|im_end|>\\n <|im_start|> reasoner:  \"\n      post_process_kwargs:\n        post_chat_template: \" <|im_start|> reasoner: \"\n\n  - node_id: \"rollout_actor\"\n    node_type: \"MODEL_INFERENCE\"\n    node_role: \"ROLLOUT\"\n    config: *actor2_config\n    agent_group: 1\n    dependencies: \n     - \"rollout_reasoner\"\n    agent_options:\n      obs_with_env: true\n      process_path: examples/experimental/marft/config/process.py\n      pre_process_kwargs: \n        pre_chat_template: \"<|im_start|>system: Two LLM agents (Reasoner -> Actor) collaborate step-by-step. You are the **Actor**: Execute operations using original problem, action history, and Reasoner's guidance. 
Give the final output within \\\\boxed{}.<|im_end|>\\n ${prompt} <|im_start|> actor: \"\n      post_process_kwargs:\n        post_chat_template: \" <|im_start|> actor: \"\n      env_path: [examples/experimental/marft/config/math_env.py:MathEnv]\n\n  - node_id: \"function_reward\"\n    node_type: \"COMPUTE\"\n    node_role: \"REWARD\"\n    agent_group: 1\n    dependencies:\n      - \"rollout_actor\"\n\n  - node_id: \"actor_1_old_log_prob\"\n    node_type: \"MODEL_TRAIN\"\n    node_role: \"ACTOR\"\n    only_forward_compute: true\n    agent_group: 0\n    config: *actor1_config    \n    dependencies:\n      - \"function_reward\"\n\n  - node_id: \"actor_2_old_log_prob\"\n    node_type: \"MODEL_TRAIN\"\n    node_role: \"ACTOR\"\n    only_forward_compute: true\n    agent_group: 1\n    config: *actor2_config    \n    dependencies:\n      - \"actor_1_old_log_prob\"\n\n  - node_id: \"reference_1_log_prob\"\n    node_type: \"MODEL_TRAIN\"\n    node_role: \"REFERENCE\"\n    agent_group: 0\n    dependencies:\n      - \"actor_2_old_log_prob\"\n\n  - node_id: \"reference_2_log_prob\"\n    node_type: \"MODEL_TRAIN\"\n    node_role: \"REFERENCE\"\n    agent_group: 1\n    dependencies:\n      - \"reference_1_log_prob\"\n\n  - node_id: \"critic_1_value\"\n    node_type: \"MODEL_TRAIN\"\n    node_role: \"CRITIC\"\n    agent_group: 0\n    only_forward_compute: true\n    dependencies:\n      - \"reference_2_log_prob\"\n\n  - node_id: \"critic_2_value\"\n    node_type: \"MODEL_TRAIN\"\n    node_role: \"CRITIC\"\n    agent_group: 1\n    only_forward_compute: true\n    dependencies:\n      - \"critic_1_value\"\n    agent_options:\n      share_instance: 0\n  \n  - node_id: \"calculate_2_advantages\"\n    node_type: \"COMPUTE\"\n    node_role: \"ADVANTAGE\"\n    agent_group: 1\n    dependencies:\n      - \"critic_2_value\"\n\n  - node_id: \"critic_1_train\"\n    node_type: \"MODEL_TRAIN\"\n    node_role: \"CRITIC\"\n    agent_group: 0\n    dependencies:\n      - \"calculate_2_advantages\"\n\n  - node_id: \"critic_2_train\"\n    node_type: \"MODEL_TRAIN\"\n    node_role: \"CRITIC\"\n    agent_group: 1\n    dependencies:\n      - \"critic_1_train\"\n    agent_options:\n      share_instance: 0\n\n  - node_id: \"actor_1_train\"\n    node_type: \"MODEL_TRAIN\"\n    node_role: \"ACTOR\"\n    agent_group: 0\n    config: *actor1_config\n    agent_options:\n      train_cycle: 15\n    dependencies:\n      - \"critic_2_train\"\n\n  - node_id: \"actor_2_train\"\n    node_type: \"MODEL_TRAIN\"\n    node_role: \"ACTOR\"\n    agent_group: 1\n    config: *actor2_config\n    agent_options:\n      train_cycle: 15\n    dependencies:\n      - \"actor_1_train\"\n\n"
  },
  {
    "path": "examples/experimental/marft/config/workflow_marft_code.yaml",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\ndag_id: \"marft_ppo_training_pipeline\"\ndescription: \"This is MARFT DAG workflow configured via YAML.\"\n\nactor_1_config: &actor1_config\n  rollout.log_prob_micro_batch_size_per_gpu: 16\n  rollout.tensor_model_parallel_size: 4\n  rollout.gpu_memory_utilization: 0.3\n  rollout.n: 1\n  \n\n\nactor_2_config: &actor2_config\n  rollout.log_prob_micro_batch_size_per_gpu: 16\n  rollout.tensor_model_parallel_size: 4\n  rollout.gpu_memory_utilization: 0.3\n  rollout.n: 1\n  \nnodes:\n  - node_id: \"rollout_reasoner\"\n    node_type: \"MODEL_INFERENCE\"\n    node_role: \"ROLLOUT\"\n    config: *actor1_config\n    agent_group: 0\n    dependencies: []\n    agent_options:\n      obs_with_env: true\n      process_path: examples/experimental/marft/config/process.py\n      pre_process_kwargs: \n        pre_chat_template: \"Two LLM agents (Reasoner → Coder) collaborate to solve Codeforces Python coding problems.\\nYou are the **Reasoner**: Analyze the problem statement, constraints, and expected behavior.\\nIdentify edge cases, break the problem into logical steps, and suggest a high-level algorithmic plan.\\nYou may include helpful pseudocode and edge case analysis, but do **not** write actual Python code.\\n<|im_start|>problem: ${prompt}\\n reasoner: \"\n      post_process_kwargs:\n        post_chat_template: \" reasoner: \"\n\n  - node_id: \"rollout_actor\"\n    node_type: \"MODEL_INFERENCE\"\n    node_role: \"ROLLOUT\"\n    config: *actor2_config\n    agent_group: 1\n    dependencies: \n     - \"rollout_reasoner\"\n    agent_options:\n      obs_with_env: true\n      process_path: examples/experimental/marft/config/process.py\n      pre_process_kwargs: \n        pre_chat_template: \"Two LLM agents (Reasoner → Coder) collaborate to solve Codeforces Python coding problems.\\nYou are the **Coder**: Implement the Reasoner's plan using efficient and correct Python code.\\nHandle edge cases, follow the provided strategy, and ensure clarity and correctness.\\nAlways use Python.\\nPlace your complete solution below the line starting with '```python```'.\\n${prompt} coder: \"\n      post_process_kwargs:\n        post_chat_template: \" coder: \"\n      env_path: [examples/experimental/marft/config/code_env.py:CodeEnv]\n\n\n  - node_id: \"function_reward\"\n    node_type: \"COMPUTE\"\n    node_role: \"REWARD\"\n    agent_group: 1\n    dependencies:\n      - \"rollout_actor\"\n\n  - node_id: \"actor_1_old_log_prob\"\n    node_type: \"MODEL_TRAIN\"\n    node_role: \"ACTOR\"\n    only_forward_compute: true\n    agent_group: 0\n    config: *actor1_config    \n    dependencies:\n      - \"function_reward\"\n\n  - node_id: \"actor_2_old_log_prob\"\n    node_type: \"MODEL_TRAIN\"\n    node_role: \"ACTOR\"\n    only_forward_compute: true\n    agent_group: 1\n    config: *actor2_config    \n    dependencies:\n      - \"actor_1_old_log_prob\"\n\n  - node_id: 
\"reference_1_log_prob\"\n    node_type: \"MODEL_TRAIN\"\n    node_role: \"REFERENCE\"\n    agent_group: 0\n    dependencies:\n      - \"actor_2_old_log_prob\"\n\n  - node_id: \"reference_2_log_prob\"\n    node_type: \"MODEL_TRAIN\"\n    node_role: \"REFERENCE\"\n    agent_group: 1\n    dependencies:\n      - \"reference_1_log_prob\"\n\n  - node_id: \"critic_1_value\"\n    node_type: \"MODEL_TRAIN\"\n    node_role: \"CRITIC\"\n    agent_group: 0\n    only_forward_compute: true\n    dependencies:\n      - \"reference_2_log_prob\"\n\n  - node_id: \"critic_2_value\"\n    node_type: \"MODEL_TRAIN\"\n    node_role: \"CRITIC\"\n    agent_group: 1\n    only_forward_compute: true\n    dependencies:\n      - \"critic_1_value\"\n    agent_options:\n      share_instance: 0\n  \n  - node_id: \"calculate_2_advantages\"\n    node_type: \"COMPUTE\"\n    node_role: \"ADVANTAGE\"\n    agent_group: 1\n    dependencies:\n      - \"critic_2_value\"\n\n  - node_id: \"critic_1_train\"\n    node_type: \"MODEL_TRAIN\"\n    node_role: \"CRITIC\"\n    agent_group: 0\n    dependencies:\n      - \"calculate_2_advantages\"\n\n  - node_id: \"critic_2_train\"\n    node_type: \"MODEL_TRAIN\"\n    node_role: \"CRITIC\"\n    agent_group: 1\n    dependencies:\n      - \"critic_1_train\"\n    agent_options:\n      share_instance: 0\n\n  - node_id: \"actor_1_train\"\n    node_type: \"MODEL_TRAIN\"\n    node_role: \"ACTOR\"\n    agent_group: 0\n    config: *actor1_config\n    agent_options:\n      train_cycle: 15\n    dependencies:\n      - \"critic_2_train\"\n\n  - node_id: \"actor_2_train\"\n    node_type: \"MODEL_TRAIN\"\n    node_role: \"ACTOR\"\n    agent_group: 1\n    config: *actor2_config\n    agent_options:\n      train_cycle: 15\n    dependencies:\n      - \"actor_1_train\"\n\n"
  },
  {
    "path": "examples/experimental/marft/run_qwen2_5-3b_marft.sh",
    "content": "#!/usr/bin/env bash\n# ===================================================================================\n# ===                       USER CONFIGURATION SECTION                            ===\n# ===================================================================================\n\n# --- Experiment and Model Definition ---\nexport DATASET=deepscaler\nexport ALG=gae_marft\nexport MODEL_NAME=qwen3-1.7b\n\n# --- Path Definitions ---\nexport HOME={your_home_path}\nexport TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet\nexport TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet\nexport MODEL_PATH=$HOME/data/models/Qwen3-1.7B-Instruct\n\n# Base output paths\nexport BASE_CKPT_PATH=ckpts\nexport BASE_TENSORBOARD_PATH=tensorboard\n\n# --- Multi-node (Multi-machine) distributed training environments ---\n\n# Uncomment the following line and set the correct network interface if needed for distributed backend\n# export GLOO_SOCKET_IFNAME=bond0  # Modify as needed\n\n# --- Key Training Hyperparameters ---\nexport TRAIN_BATCH_SIZE_PER_NODE=128\nexport PPO_MINI_BATCH_SIZE_PER_NODE=64\nexport PPO_MICRO_BATCH_SIZE_PER_GPU=4\nexport MAX_PROMPT_LENGTH=10240\nexport MAX_RESPONSE_LENGTH=2048\nexport ROLLOUT_GPU_MEMORY_UTILIZATION=0.3\nexport ROLLOUT_TP=4\nexport ROLLOUT_N=1\nexport SAVE_FREQ=30\nexport TEST_FREQ=10\nexport TOTAL_EPOCHS=30\nexport MAX_CKPT_KEEP=5\nexport PROJECT_DIR=\"$(pwd)\"\nexport DAG_WORKERFLOW=$PROJECT_DIR/examples/experimental/marft/config/workflow_marft.yaml\n# --- Define the Training Command and its Arguments ---\nTRAINING_CMD=(\n    python3 -m siirl.main_dag\n    algorithm.adv_estimator=\\$ALG\n    data.train_files=\\$TRAIN_DATA_PATH\n    data.val_files=\\$TEST_DATA_PATH\n    data.train_batch_size=\\$TRAIN_BATCH_SIZE\n    data.max_prompt_length=\\$MAX_PROMPT_LENGTH\n    data.max_response_length=\\$MAX_RESPONSE_LENGTH\n    data.filter_overlong_prompts=True\n    data.return_raw_chat=True\n    data.truncation='error'\n    data.shuffle=False\n    actor_rollout_ref.model.path=\\$MODEL_PATH\n    actor_rollout_ref.actor.optim.lr=1e-6\n    actor_rollout_ref.model.use_remove_padding=True\n    actor_rollout_ref.model.use_fused_kernels=False\n    actor_rollout_ref.actor.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.actor.use_kl_loss=True\n    actor_rollout_ref.actor.grad_clip=0.5\n    actor_rollout_ref.actor.clip_ratio=0.2\n    actor_rollout_ref.actor.kl_loss_coef=0.01\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl\n    actor_rollout_ref.model.enable_gradient_checkpointing=True\n    actor_rollout_ref.actor.fsdp_config.param_offload=False\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.rollout.tensor_model_parallel_size=\\$ROLLOUT_TP\n    actor_rollout_ref.rollout.name=sglang\n    actor_rollout_ref.rollout.gpu_memory_utilization=\\$ROLLOUT_GPU_MEMORY_UTILIZATION\n    actor_rollout_ref.rollout.max_model_len=12288\n    actor_rollout_ref.rollout.enable_chunked_prefill=False\n    actor_rollout_ref.rollout.enforce_eager=False\n    actor_rollout_ref.rollout.free_cache_engine=False\n    actor_rollout_ref.rollout.agent.rewards_with_env=True\n    actor_rollout_ref.rollout.multi_turn.max_assistant_turns=3\n    actor_rollout_ref.rollout.multi_turn.use_all_traj=False\n    actor_rollout_ref.rollout.n=\\$ROLLOUT_N\n    
actor_rollout_ref.rollout.prompt_length=\\$MAX_PROMPT_LENGTH  \n    actor_rollout_ref.rollout.response_length=\\$MAX_RESPONSE_LENGTH\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.ref.fsdp_config.param_offload=True\n    critic.optim.lr=1e-5\n    critic.model.use_remove_padding=True\n    critic.model.path=\\$MODEL_PATH\n    critic.model.enable_gradient_checkpointing=True\n    critic.use_dynamic_bsz=False\n    critic.ppo_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    critic.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n    critic.ppo_max_token_len_per_gpu=12288\n    critic.model.fsdp_config.param_offload=False\n    critic.model.fsdp_config.optimizer_offload=False\n    algorithm.kl_ctrl.kl_coef=0.001\n    algorithm.use_kl_in_reward=False\n    trainer.critic_warmup=0\n    trainer.logger=['console','tensorboard']\n    trainer.project_name=\\$PROJECT_NAME\n    trainer.experiment_name=\\$EXPERIMENT_NAME\n    trainer.n_gpus_per_node=\\$N_GPUS_PER_NODE\n    trainer.nnodes=\\$NNODES\n    trainer.save_freq=\\$SAVE_FREQ\n    trainer.test_freq=\\$TEST_FREQ\n    trainer.total_epochs=\\$TOTAL_EPOCHS\n    trainer.resume_mode=auto\n    trainer.max_actor_ckpt_to_keep=\\$MAX_CKPT_KEEP\n    trainer.default_local_dir=\\$CKPT_PATH\n    trainer.val_before_train=False\n    dag.workflow_path=\\$DAG_WORKERFLOW \n)\n\n# ===================================================================================\n# ===                  MAIN EXECUTION LOGIC & INFRASTRUCTURE                      ===\n# ===================================================================================\n\n# --- Boilerplate Setup ---\nset -e\nset -o pipefail\nset -x\n\n# --- Infrastructure & Boilerplate Functions ---\nstart_ray_cluster() {\n    local RAY_HEAD_WAIT_TIMEOUT=600\n    export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200\n    export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120\n\n    local ray_start_common_opts=(\n        --num-gpus \"$N_GPUS_PER_NODE\"\n        --object-store-memory 100000000000\n        --memory 100000000000\n    )\n\n    if [ \"$NNODES\" -gt 1 ]; then\n        if [ \"$NODE_RANK\" = \"0\" ]; then\n            echo \"INFO: Starting Ray head node on $(hostname)...\"\n            export RAY_ADDRESS=\"$RAY_MASTER_ADDR:$RAY_MASTER_PORT\"\n            ray start --head --port=\"$RAY_MASTER_PORT\" --dashboard-port=\"$RAY_DASHBOARD_PORT\" \"${ray_start_common_opts[@]}\" --system-config='{\"gcs_server_request_timeout_seconds\": 60, \"gcs_rpc_server_reconnect_timeout_s\": 60}'\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$RAY_ADDRESS\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head node. Exiting.\" >&2; ray stop --force; exit 1; fi\n                echo \"Head node not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head node is healthy.\"\n        else\n            local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n            echo \"INFO: Worker node $(hostname) waiting for head at $head_node_address...\"\n            local start_time=$(date +%s)\n            while ! 
ray health-check --address \"$head_node_address\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head. Exiting.\" >&2; exit 1; fi\n                echo \"Head not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head is healthy. Worker starting...\"\n            ray start --address=\"$head_node_address\" \"${ray_start_common_opts[@]}\" \n        fi\n    else\n        echo \"INFO: Starting Ray in single-node mode...\"\n        ray start --head \"${ray_start_common_opts[@]}\"\n    fi\n}\n\n# --- Main Execution Function ---\nmain() {\n    local timestamp=$(date +\"%Y%m%d_%H%M%S\")\n    ray stop --force\n\n    \n\n    export N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}\n    export NNODES=${PET_NNODES:-1}\n    export NODE_RANK=${PET_NODE_RANK:-0}\n    export MASTER_ADDR=${MASTER_ADDR:-localhost}\n\n    export CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes\n    export PROJECT_NAME=siirl_${DATASET}_${ALG}\n    export EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment\n    export TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp\n    export SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp\n\n    export TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))\n    export PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))\n\n    export VLLM_USE_V1=1\n    export GLOO_SOCKET_TIMEOUT=600\n    export GLOO_TCP_TIMEOUT=600\n    export GLOO_LOG_LEVEL=DEBUG\n    export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}\n    export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}\n    export RAY_MASTER_ADDR=$MASTER_ADDR\n\n    start_ray_cluster\n\n    if [ \"$NNODES\" -gt 1 ] && [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"Waiting for all $NNODES nodes to join...\"\n        local TIMEOUT=600; local start_time=$(date +%s)\n        while true; do\n            if [ \"$(( $(date +%s) - start_time ))\" -ge \"$TIMEOUT\" ]; then echo \"Error: Timeout waiting for nodes.\" >&2; exit 1; fi\n            local ready_nodes=$(ray list nodes --format=json | python3 -c \"import sys, json; print(len(json.load(sys.stdin)))\")\n            if [ \"$ready_nodes\" -ge \"$NNODES\" ]; then break; fi\n            echo \"Waiting... ($ready_nodes / $NNODES nodes ready)\"\n            sleep 5\n        done\n        echo \"All $NNODES nodes have joined.\"\n    fi\n\n    if [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"INFO [RANK 0]: Starting main training command.\"\n        eval \"${TRAINING_CMD[@]}\" \"$@\"\n        echo \"INFO [RANK 0]: Training finished.\"\n        sleep 30; ray stop --force >/dev/null 2>&1\n    elif [ \"$NNODES\" -gt 1 ]; then\n        local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n        echo \"INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address.\"\n        while ray health-check --address \"$head_node_address\" &>/dev/null; do sleep 15; done\n        echo \"INFO [RANK $NODE_RANK]: Head node down. Exiting.\"\n    fi\n\n    echo \"INFO: Script finished on rank $NODE_RANK.\"\n}\n\n# --- Script Entrypoint ---\nmain \"$@\"\n"
  },
  {
    "path": "examples/experimental/multiturn_server/run_qwen2_5-3b_grpo_multiturn_vllm.sh",
    "content": "#!/usr/bin/env bash\n# ===================================================================================\n# ===                       USER CONFIGURATION SECTION                            ===\n# ===================================================================================\n\n# --- Experiment and Model Definition ---\nexport DATASET=deepscaler\nexport ALG=grpo\nexport MODEL_NAME=qwen2.5-3b\n\n# --- Path Definitions ---\nexport HOME={your_home_path}\nexport TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet\nexport TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet\nexport MODEL_PATH=$HOME/data/models/Qwen3-1.7B-Instruct\n\n# Base output paths\nexport BASE_CKPT_PATH=ckpts\nexport BASE_TENSORBOARD_PATH=tensorboard\n\n# --- Key Training Hyperparameters ---\nexport TRAIN_BATCH_SIZE_PER_NODE=256\nexport PPO_MINI_BATCH_SIZE_PER_NODE=256\nexport PPO_MICRO_BATCH_SIZE_PER_GPU=8\nexport MAX_PROMPT_LENGTH=1024\nexport MAX_RESPONSE_LENGTH=1024\nexport ROLLOUT_GPU_MEMORY_UTILIZATION=0.4\nexport ROLLOUT_TP=2\nexport ROLLOUT_N=8\nexport SAVE_FREQ=30\nexport TEST_FREQ=10\nexport TOTAL_EPOCHS=30\nexport MAX_CKPT_KEEP=5\nexport PROJECT_DIR=\"$(pwd)\"\nexport CONFIG_PATH=$PROJECT_DIR/examples/multi_turn/config\nexport TOOL_CONFIG_PATH=$PROJECT_DIR/examples/multi_turn/config/tool_config/gsm8k_tool_config.yaml\nexport INTERACTION_CONFIG_PATH=$PROJECT_DIR/examples/multi_turn/config/interaction_config/gsm8k_interaction_config.yaml\n\n# --- Distributed Training & Infrastructure ---\nexport N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}\nexport NNODES=${PET_NNODES:-1}\nexport NODE_RANK=${PET_NODE_RANK:-0}\nexport MASTER_ADDR=${MASTER_ADDR:-localhost}\n\n# --- Output Paths and Experiment Naming ---\nexport CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes\nexport PROJECT_NAME=siirl_${DATASET}_${ALG}\nexport EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment\nexport TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp\nexport SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp\n\n# --- Calculated Global Hyperparameters ---\nexport TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))\nexport PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))\n\n# --- Define the Training Command and its Arguments ---\nTRAINING_CMD=(\n    python3 -m siirl.main_dag\n    --config-path=\\$CONFIG_PATH \n    --config-name='gsm8k_multiturn_grpo' \n    algorithm.adv_estimator=\\$ALG\n    data.train_files=\\$TRAIN_DATA_PATH\n    data.val_files=\\$TEST_DATA_PATH\n    data.train_batch_size=\\$TRAIN_BATCH_SIZE\n    data.max_prompt_length=\\$MAX_PROMPT_LENGTH\n    data.max_response_length=\\$MAX_RESPONSE_LENGTH\n    data.filter_overlong_prompts=True\n    data.return_raw_chat=True \n    data.truncation='error'\n    data.shuffle=False\n    actor_rollout_ref.model.path=\\$MODEL_PATH\n    actor_rollout_ref.actor.optim.lr=1e-6\n    actor_rollout_ref.model.use_remove_padding=True\n    actor_rollout_ref.model.use_fused_kernels=False\n    actor_rollout_ref.actor.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.actor.use_kl_loss=True\n    actor_rollout_ref.actor.grad_clip=0.5\n    actor_rollout_ref.actor.clip_ratio=0.2\n    actor_rollout_ref.actor.kl_loss_coef=0.01\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl\n    
actor_rollout_ref.model.enable_gradient_checkpointing=True\n    actor_rollout_ref.actor.fsdp_config.param_offload=False\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.rollout.tensor_model_parallel_size=\\$ROLLOUT_TP\n    actor_rollout_ref.rollout.name=vllm\n    actor_rollout_ref.rollout.mode=async\n    actor_rollout_ref.rollout.gpu_memory_utilization=\\$ROLLOUT_GPU_MEMORY_UTILIZATION\n    actor_rollout_ref.rollout.max_model_len=8192\n    actor_rollout_ref.rollout.enable_chunked_prefill=False\n    actor_rollout_ref.rollout.enforce_eager=False\n    actor_rollout_ref.rollout.free_cache_engine=False\n    actor_rollout_ref.rollout.n=\\$ROLLOUT_N\n    actor_rollout_ref.rollout.prompt_length=\\$MAX_PROMPT_LENGTH  \n    actor_rollout_ref.rollout.response_length=\\$MAX_RESPONSE_LENGTH\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.ref.fsdp_config.param_offload=True\n    algorithm.kl_ctrl.kl_coef=0.001\n    trainer.critic_warmup=0\n    trainer.logger=['console','tensorboard']\n    trainer.project_name=\\$PROJECT_NAME\n    trainer.experiment_name=\\$EXPERIMENT_NAME\n    trainer.n_gpus_per_node=\\$N_GPUS_PER_NODE\n    trainer.nnodes=\\$NNODES\n    trainer.save_freq=\\$SAVE_FREQ\n    trainer.test_freq=\\$TEST_FREQ\n    trainer.total_epochs=\\$TOTAL_EPOCHS\n    trainer.resume_mode=auto\n    trainer.max_actor_ckpt_to_keep=\\$MAX_CKPT_KEEP\n    trainer.default_local_dir=\\$CKPT_PATH\n    trainer.val_before_train=True\n    actor_rollout_ref.rollout.multi_turn.tool_config_path=\"\\$TOOL_CONFIG_PATH\" \n    actor_rollout_ref.rollout.multi_turn.interaction_config_path=\"\\$INTERACTION_CONFIG_PATH\" \n    actor_rollout_ref.rollout.multi_turn.max_assistant_turns=1\n    actor_rollout_ref.rollout.multi_turn.max_user_turns=1\n    actor_rollout_ref.rollout.agent.agent_name=\"tool_agent\"\n)\n\n# ===================================================================================\n# ===                  MAIN EXECUTION LOGIC & INFRASTRUCTURE                      ===\n# ===================================================================================\n\n# --- Boilerplate Setup ---\nset -e\nset -o pipefail\nset -x\n\n# --- Infrastructure & Boilerplate Functions ---\nstart_ray_cluster() {\n    local RAY_HEAD_WAIT_TIMEOUT=600\n    export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200\n    export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120\n\n    local ray_start_common_opts=(\n        --num-gpus \"$N_GPUS_PER_NODE\"\n        --object-store-memory 100000000000\n        --memory 100000000000\n    )\n\n    if [ \"$NNODES\" -gt 1 ]; then\n        if [ \"$NODE_RANK\" = \"0\" ]; then\n            echo \"INFO: Starting Ray head node on $(hostname)...\"\n            export RAY_ADDRESS=\"$RAY_MASTER_ADDR:$RAY_MASTER_PORT\"\n            ray start --head --port=\"$RAY_MASTER_PORT\" --dashboard-port=\"$RAY_DASHBOARD_PORT\" \"${ray_start_common_opts[@]}\" --system-config='{\"gcs_server_request_timeout_seconds\": 60, \"gcs_rpc_server_reconnect_timeout_s\": 60}'\n            local start_time=$(date +%s)\n            while ! 
ray health-check --address \"$RAY_ADDRESS\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head node. Exiting.\" >&2; ray stop --force; exit 1; fi\n                echo \"Head node not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head node is healthy.\"\n        else\n            local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n            echo \"INFO: Worker node $(hostname) waiting for head at $head_node_address...\"\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$head_node_address\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head. Exiting.\" >&2; exit 1; fi\n                echo \"Head not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head is healthy. Worker starting...\"\n            ray start --address=\"$head_node_address\" \"${ray_start_common_opts[@]}\"\n        fi\n    else\n        echo \"INFO: Starting Ray in single-node mode...\"\n        ray start --head \"${ray_start_common_opts[@]}\"\n    fi\n}\n\n# --- Main Execution Function ---\nmain() {\n    local timestamp=$(date +\"%Y%m%d_%H%M%S\")\n    ray stop --force\n\n    if [ \"$HOME\" = \"{your_home_path}\" ] || [ -z \"$HOME\" ]; then echo \"ERROR: Please set 'HOME' variable.\" >&2; exit 1; fi\n    \n    export VLLM_USE_V1=1\n    export GLOO_SOCKET_TIMEOUT=600\n    export GLOO_TCP_TIMEOUT=600\n    export GLOO_LOG_LEVEL=DEBUG\n    export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}\n    export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}\n    export RAY_MASTER_ADDR=$MASTER_ADDR\n    \n    start_ray_cluster\n\n    if [ \"$NNODES\" -gt 1 ] && [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"Waiting for all $NNODES nodes to join...\"\n        local TIMEOUT=600; local start_time=$(date +%s)\n        while true; do\n            if [ \"$(( $(date +%s) - start_time ))\" -ge \"$TIMEOUT\" ]; then echo \"Error: Timeout waiting for nodes.\" >&2; exit 1; fi\n            local ready_nodes=$(ray list nodes --format=json | python3 -c \"import sys, json; print(len(json.load(sys.stdin)))\")\n            if [ \"$ready_nodes\" -ge \"$NNODES\" ]; then break; fi\n            echo \"Waiting... ($ready_nodes / $NNODES nodes ready)\"\n            sleep 5\n        done\n        echo \"All $NNODES nodes have joined.\"\n    fi\n\n    if [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"INFO [RANK 0]: Starting main training command.\"\n        eval \"${TRAINING_CMD[@]}\" \"$@\"\n        echo \"INFO [RANK 0]: Training finished.\"\n        sleep 30; ray stop --force --force >/dev/null 2>&1\n    elif [ \"$NNODES\" -gt 1 ]; then\n        local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n        echo \"INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address.\"\n        while ray health-check --address \"$head_node_address\" &>/dev/null; do sleep 15; done\n        echo \"INFO [RANK $NODE_RANK]: Head node down. Exiting.\"\n    fi\n\n    echo \"INFO: Script finished on rank $NODE_RANK.\"\n}\n\n# --- Script Entrypoint ---\nmain \"$@\"\n"
  },
  {
    "path": "examples/grpo_trainer/run_qwen2_5-32b-metax.sh",
    "content": " #!/usr/bin/env bash\n# ===================================================================================\n# ===                       USER CONFIGURATION SECTION                            ===\n# ===================================================================================\n\n# --- Experiment and Model Definition ---\nexport DATASET=deepscaler #deepscaler\nexport ALG=grpo\nexport MODEL_NAME=qwen2.5-32b\n\n# --- Path Definitions ---\nexport HOME=/workspace/\n\nexport HOME={your_home_path}\nexport TRAIN_DATA_PATH=$HOME/datasets/$DATASET/train.parquet\nexport TEST_DATA_PATH=$HOME/datasets/$DATASET/test.parquet\nexport MODEL_PATH=$HOME/models/Qwen2.5-32B-Instruct\n\n# Base output paths\nexport BASE_CKPT_PATH=$HOME/siirl_ckpts2\nexport BASE_TENSORBOARD_PATH=tensorboard\n\n# --- Key Training Hyperparameters ---\nexport TRAIN_BATCH_SIZE_PER_NODE=512 #1024  \nexport PPO_MINI_BATCH_SIZE_PER_NODE=128\nexport PPO_MICRO_BATCH_SIZE_PER_GPU=4\nexport MAX_PROMPT_LENGTH=2048\nexport MAX_RESPONSE_LENGTH=2048\nexport ROLLOUT_GPU_MEMORY_UTILIZATION=0.45\nexport ROLLOUT_TP=4\nexport ROLLOUT_N=4\nexport SAVE_FREQ=-1\nexport TEST_FREQ=10\nexport TOTAL_EPOCHS=30\nexport MAX_CKPT_KEEP=5\n\n# --- Multi-node (Multi-machine) distributed training environments ---\n\n# Uncomment the following line and set the correct network interface if needed for distributed backend\n# export GLOO_SOCKET_IFNAME=bond0  # Modify as needed\n\n# --- Distributed Training & Infrastructure ---\nexport N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}\nexport NNODES=${PET_NNODES:-1}\nexport NODE_RANK=${PET_NODE_RANK:-0}\nexport MASTER_ADDR=${MASTER_ADDR:-localhost}\n\n# --- Output Paths and Experiment Naming ---\nexport CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes\nexport PROJECT_NAME=siirl_${DATASET}_${ALG}\nexport EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment\nexport TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp\nexport SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp\n\n# --- Calculated Global Hyperparameters ---\nexport TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))\nexport PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))\n\n# mx gpu env\nexport MACA_PATH=/opt/maca\nexport CUCC_PATH=${MACA_PATH}/tools/cu-bridge\nexport CUDA_PATH=${CUCC_PATH}\nexport MACA_CLANG_PATH=$MACA_PATH/mxgpu_llvm/bin\nexport PATH=${CUDA_PATH}/bin:${MACA_CLANG_PATH}:${PATH}\nexport LD_LIBRARY_PATH=${MACA_PATH}/tools/cu-bridge/lib/:${MACA_PATH}/lib:${MACA_PATH}/ompi/lib:${MACA_PATH}/mxgpu_llvm/lib:${LD_LIBRARY_PATH}\nexport PYTORCH_ENABLE_SAME_RAND_A100=1\nexport MACA_SMALL_PAGESIZE_ENABLE=1\nexport SET_DEVICE_NUMA_PREFERRED=1\n# export CUDA_DEVICE_MAX_CONNECTIONS=1\nexport MCPYTORCH_DISABLE_PRINT=1\nexport MAX_JOBS=20\nexport VLLM_USE_V1=0\nunset PYTORCH_CUDA_ALLOC_CONF\nexport MCCL_ENABLE_FC=0\n# export MACA_PRIORITY_QUEUE_POLICY=0xa11\n# export MCCL_PCIE_BUFFER_MODE=0\n# export MCCL_NET_GDR_LEVEL=SYS\n\n# export MCCL_USE_FILE_TUNING=0\n# export MCCL_ALGO=Ring\n# export MCCL_DISABLE_MULTI_NODE_FABRIC=1\n# export MCCL_DISABLE_OPTIC_LINK_=1\n\nexport MCCL_MAX_NCHANNELS=8\nexport PYTHONUNBUFFERED=1\nexport MCCL_IB_HCA=mlx5\nexport MCCL_SOCKET_IFNAME=ens1\nexport GLOO_SOCKET_IFNAME=ens1\nexport SOCKET_NIC=ens1\n\n# --- Define the Training Command and its Arguments ---\nTRAINING_CMD=(\n    python3 -m siirl.main_dag\n    algorithm.adv_estimator=\\$ALG\n    
data.train_files=\\$TRAIN_DATA_PATH\n    data.val_files=\\$TEST_DATA_PATH\n    data.train_batch_size=\\$TRAIN_BATCH_SIZE\n    data.max_prompt_length=\\$MAX_PROMPT_LENGTH\n    data.max_response_length=\\$MAX_RESPONSE_LENGTH\n    data.filter_overlong_prompts=True\n    data.truncation='error'\n    data.auto_repeat=True\n    data.shuffle=False\n    actor_rollout_ref.model.path=\\$MODEL_PATH\n    actor_rollout_ref.actor.optim.lr=1e-6\n    actor_rollout_ref.model.use_remove_padding=True\n    actor_rollout_ref.model.use_fused_kernels=False\n    actor_rollout_ref.actor.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=$(($PPO_MICRO_BATCH_SIZE_PER_GPU / 1))\n    actor_rollout_ref.actor.use_kl_loss=True\n    actor_rollout_ref.actor.grad_clip=0.5\n    actor_rollout_ref.actor.clip_ratio=0.2\n    actor_rollout_ref.actor.entropy_coeff=0\n    actor_rollout_ref.actor.kl_loss_coef=0.01\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl\n    actor_rollout_ref.model.enable_gradient_checkpointing=True\n    actor_rollout_ref.actor.fsdp_config.param_offload=True\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=64\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.rollout.tensor_model_parallel_size=\\$ROLLOUT_TP\n    actor_rollout_ref.rollout.name=vllm\n    actor_rollout_ref.rollout.gpu_memory_utilization=\\$ROLLOUT_GPU_MEMORY_UTILIZATION\n    actor_rollout_ref.rollout.enable_chunked_prefill=True\n    actor_rollout_ref.rollout.enforce_eager=True\n    actor_rollout_ref.rollout.free_cache_engine=True\n    actor_rollout_ref.rollout.n=\\$ROLLOUT_N\n    actor_rollout_ref.rollout.prompt_length=\\$MAX_PROMPT_LENGTH  \n    actor_rollout_ref.rollout.response_length=\\$MAX_RESPONSE_LENGTH\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=$(($PPO_MICRO_BATCH_SIZE_PER_GPU * 4))\n    actor_rollout_ref.ref.fsdp_config.param_offload=True\n    algorithm.use_kl_in_reward=False\n    trainer.critic_warmup=0\n    trainer.logger=['console','tensorboard']\n    trainer.project_name=\\$PROJECT_NAME\n    trainer.experiment_name=\\$EXPERIMENT_NAME\n    trainer.n_gpus_per_node=\\$N_GPUS_PER_NODE\n    trainer.nnodes=\\$NNODES\n    trainer.save_freq=\\$SAVE_FREQ\n    trainer.test_freq=\\$TEST_FREQ\n    trainer.total_epochs=\\$TOTAL_EPOCHS\n    trainer.resume_mode=auto\n    trainer.max_actor_ckpt_to_keep=\\$MAX_CKPT_KEEP\n    trainer.default_local_dir=\\$CKPT_PATH\n    trainer.val_before_train=True\n)\n\n# ===================================================================================\n# ===                  MAIN EXECUTION LOGIC & INFRASTRUCTURE                      ===\n# ===================================================================================\n\n# --- Boilerplate Setup ---\nset -e\nset -o pipefail\nset -x\n\n# --- Infrastructure & Boilerplate Functions ---\nstart_ray_cluster() {\n    local RAY_HEAD_WAIT_TIMEOUT=600\n    export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200\n    export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120\n\n    local ray_start_common_opts=(\n        --num-gpus \"$N_GPUS_PER_NODE\"\n        --object-store-memory 100000000000\n        --memory 100000000000\n    )\n\n    if [ \"$NNODES\" -gt 1 ]; then\n        if [ \"$NODE_RANK\" = \"0\" ]; then\n            echo \"INFO: Starting Ray 
head node on $(hostname)...\"\n            export RAY_ADDRESS=\"$RAY_MASTER_ADDR:$RAY_MASTER_PORT\"\n            ray start --head --port=\"$RAY_MASTER_PORT\" --dashboard-port=\"$RAY_DASHBOARD_PORT\" \"${ray_start_common_opts[@]}\" --system-config='{\"gcs_server_request_timeout_seconds\": 60, \"gcs_rpc_server_reconnect_timeout_s\": 60}'\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$RAY_ADDRESS\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head node. Exiting.\" >&2; ray stop --force; exit 1; fi\n                echo \"Head node not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head node is healthy.\"\n        else\n            local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n            echo \"INFO: Worker node $(hostname) waiting for head at $head_node_address...\"\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$head_node_address\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head. Exiting.\" >&2; exit 1; fi\n                echo \"Head not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head is healthy. Worker starting...\"\n            ray start --address=\"$head_node_address\" \"${ray_start_common_opts[@]}\"\n        fi\n    else\n        echo \"INFO: Starting Ray in single-node mode...\"\n        ray start --head \"${ray_start_common_opts[@]}\"\n    fi\n}\n\n# --- Main Execution Function ---\nmain() {\n    local timestamp=$(date +\"%Y%m%d_%H%M%S\")\n    ray stop --force\n\n    # export VLLM_USE_V1=0\n    export GLOO_SOCKET_TIMEOUT=600\n    export GLOO_TCP_TIMEOUT=600\n    export GLOO_LOG_LEVEL=DEBUG\n    export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}\n    export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}\n    export RAY_MASTER_ADDR=$MASTER_ADDR\n    \n    start_ray_cluster\n\n    if [ \"$NNODES\" -gt 1 ] && [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"Waiting for all $NNODES nodes to join...\"\n        local TIMEOUT=600; local start_time=$(date +%s)\n        while true; do\n            if [ \"$(( $(date +%s) - start_time ))\" -ge \"$TIMEOUT\" ]; then echo \"Error: Timeout waiting for nodes.\" >&2; exit 1; fi\n            # local ready_nodes=$(ray list nodes --format=json | python3 -c \"import sys, json; print(len(json.load(sys.stdin)))\")\n            local ready_nodes=$(ray status | grep \"node_\" | wc -l)\n            if [ \"$ready_nodes\" -ge \"$NNODES\" ]; then break; fi\n            echo \"Waiting... ($ready_nodes / $NNODES nodes ready)\"\n            sleep 5\n        done\n        echo \"All $NNODES nodes have joined.\"\n    fi\n\n    if [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"INFO [RANK 0]: Starting main training command.\"\n        eval \"${TRAINING_CMD[@]}\" \"$@\"\n        echo \"INFO [RANK 0]: Training finished.\"\n        sleep 30; ray stop --force >/dev/null 2>&1\n    elif [ \"$NNODES\" -gt 1 ]; then\n        local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n        echo \"INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address.\"\n        while ray health-check --address \"$head_node_address\" &>/dev/null; do sleep 15; done\n        echo \"INFO [RANK $NODE_RANK]: Head node down. 
Exiting.\"\n    fi\n\n    echo \"INFO: Script finished on rank $NODE_RANK.\"\n}\n\n# --- Script Entrypoint ---\nmain \"$@\"\n\n"
  },
  {
    "path": "examples/grpo_trainer/run_qwen2_5-32b-npu.sh",
    "content": " #!/usr/bin/env bash\n# ===================================================================================\n# ===                       USER CONFIGURATION SECTION                            ===\n# ===================================================================================\n\n# --- Experiment and Model Definition ---\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh\nsource /usr/local/Ascend/nnal/atb/set_env.sh\nexport LD_LIBRARY_PATH=/usr/local/Ascend/driver/:$LD_LIBRARY_PATH\nexport LD_LIBRARY_PATH=/usr/local/Ascend/driver/lib64/driver/:$LD_LIBRARY_PATH\n\nexport DATASET=deepscaler\nexport ALG=grpo\nexport MODEL_NAME=qwen2.5-32b\nexport VLLM_USE_V1=1\n\n# --- Path Definitions ---\nexport HOME={your_home_path}\nexport TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet\nexport TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet\nexport MODEL_PATH=$HOME/data/models/Qwen2.5-32B-Instruct\n\n# Base output paths\nexport BASE_CKPT_PATH=ckpts\nexport BASE_TENSORBOARD_PATH=tensorboard\n\n# --- GLOO Configuration ---\nexport GLOO_SOCKET_IFNAME=enp91s0np0\nexport HCCL_SOCKET_IFNAME=enp91s0np0\nexport GLOO_SOCKET_TIMEOUT=600\nexport GLOO_TCP_TIMEOUT=600\nexport HCCL_CONNECT_TIMEOUT=7200\nexport GLOO_LOG_LEVEL=INFO\n\n# --- Key Training Hyperparameters ---\nexport TRAIN_BATCH_SIZE_PER_NODE=1024\nexport PPO_MINI_BATCH_SIZE_PER_NODE=128\nexport PPO_MICRO_BATCH_SIZE_PER_GPU=4\nexport MAX_PROMPT_LENGTH=2048\nexport MAX_RESPONSE_LENGTH=2048\nexport ROLLOUT_GPU_MEMORY_UTILIZATION=0.5\nexport ROLLOUT_TP=4\nexport ROLLOUT_N=5\nexport SAVE_FREQ=-1\nexport TEST_FREQ=5\nexport TOTAL_EPOCHS=300\nexport MAX_CKPT_KEEP=5\n\n# --- Distributed Training & Infrastructure ---\nexport N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-16}\nexport NNODES=${PET_NNODES:-1}\nexport NODE_RANK=${PET_NODE_RANK:-0}\nexport MASTER_ADDR=${MASTER_ADDR:-localhost}\n\n# --- Output Paths and Experiment Naming ---\nexport CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes\nexport PROJECT_NAME=siirl_${DATASET}_${ALG}\nexport EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment\nexport TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp\nexport SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp\n\n# --- Calculated Global Hyperparameters ---\nexport TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))\nexport PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))\n\n# --- Define the Training Command and its Arguments ---\nTRAINING_CMD=(\n    python3 -m siirl.main_dag\n    algorithm.adv_estimator=\\$ALG\n    data.train_files=\\$TRAIN_DATA_PATH\n    data.val_files=\\$TEST_DATA_PATH\n    data.train_batch_size=\\$TRAIN_BATCH_SIZE\n    data.max_prompt_length=\\$MAX_PROMPT_LENGTH\n    data.max_response_length=\\$MAX_RESPONSE_LENGTH\n    data.filter_overlong_prompts=True\n    data.truncation='error'\n    data.auto_repeat=True\n    actor_rollout_ref.model.path=\\$MODEL_PATH\n    actor_rollout_ref.actor.optim.lr=1e-6\n    actor_rollout_ref.model.use_remove_padding=True\n    actor_rollout_ref.actor.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.actor.use_kl_loss=True\n    actor_rollout_ref.actor.grad_clip=0.3 \n    actor_rollout_ref.actor.clip_ratio=0.2\n    actor_rollout_ref.actor.entropy_coeff=0\n    actor_rollout_ref.actor.kl_loss_coef=0.01\n    
actor_rollout_ref.actor.kl_loss_type=low_var_kl\n    actor_rollout_ref.model.enable_gradient_checkpointing=True\n    actor_rollout_ref.actor.fsdp_config.param_offload=True\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=64\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.rollout.tensor_model_parallel_size=\\$ROLLOUT_TP\n    actor_rollout_ref.rollout.name=vllm\n    actor_rollout_ref.rollout.gpu_memory_utilization=\\$ROLLOUT_GPU_MEMORY_UTILIZATION\n    actor_rollout_ref.rollout.n=\\$ROLLOUT_N\n    actor_rollout_ref.rollout.prompt_length=\\$MAX_PROMPT_LENGTH  \n    actor_rollout_ref.rollout.response_length=\\$MAX_RESPONSE_LENGTH\n    actor_rollout_ref.rollout.enable_chunked_prefill=True\n    actor_rollout_ref.rollout.enforce_eager=True\n    actor_rollout_ref.rollout.free_cache_engine=True\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.ref.fsdp_config.param_offload=True\n    algorithm.use_kl_in_reward=False\n    trainer.critic_warmup=0\n    trainer.logger=['console','tensorboard']\n    trainer.project_name=\\$PROJECT_NAME\n    trainer.experiment_name=\\$EXPERIMENT_NAME\n    trainer.n_gpus_per_node=\\$N_GPUS_PER_NODE\n    trainer.nnodes=\\$NNODES\n    trainer.save_freq=\\$SAVE_FREQ\n    trainer.test_freq=\\$TEST_FREQ\n    trainer.total_epochs=\\$TOTAL_EPOCHS\n    trainer.resume_mode=auto\n    trainer.max_actor_ckpt_to_keep=\\$MAX_CKPT_KEEP\n    trainer.default_local_dir=\\$CKPT_PATH\n    trainer.val_before_train=True\n    trainer.device=npu\n)\n\n# ===================================================================================\n# ===                  MAIN EXECUTION LOGIC & INFRASTRUCTURE                      ===\n# ===================================================================================\n\n# --- Boilerplate Setup ---\nset -e\nset -o pipefail\nset -x\n\n# --- Infrastructure & Boilerplate Functions ---\nstart_ray_cluster() {\n    local RAY_HEAD_WAIT_TIMEOUT=600\n    export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200\n    export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120\n\n    local ray_start_common_opts=(\n        --num-gpus \"$N_GPUS_PER_NODE\"\n        --object-store-memory 100000000000\n        --memory 100000000000\n    )\n\n    if [ \"$NNODES\" -gt 1 ]; then\n        if [ \"$NODE_RANK\" = \"0\" ]; then\n            echo \"INFO: Starting Ray head node on $(hostname)...\"\n            export RAY_ADDRESS=\"$RAY_MASTER_ADDR:$RAY_MASTER_PORT\"\n            ray start --head --port=\"$RAY_MASTER_PORT\" --dashboard-port=\"$RAY_DASHBOARD_PORT\" \"${ray_start_common_opts[@]}\" --system-config='{\"gcs_server_request_timeout_seconds\": 60, \"gcs_rpc_server_reconnect_timeout_s\": 60}'\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$RAY_ADDRESS\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head node. Exiting.\" >&2; ray stop --force; exit 1; fi\n                echo \"Head node not healthy yet. 
Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head node is healthy.\"\n        else\n            local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n            echo \"INFO: Worker node $(hostname) waiting for head at $head_node_address...\"\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$head_node_address\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head. Exiting.\" >&2; exit 1; fi\n                echo \"Head not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head is healthy. Worker starting...\"\n            ray start --address=\"$head_node_address\" \"${ray_start_common_opts[@]}\"\n        fi\n    else\n        echo \"INFO: Starting Ray in single-node mode...\"\n        ray start --head \"${ray_start_common_opts[@]}\"\n    fi\n}\n\n# --- Main Execution Function ---\nmain() {\n    local timestamp=$(date +\"%Y%m%d_%H%M%S\")\n\n    ray stop --force\n    echo \"Cleaning up residual distributed processes...\"\n    pkill -f ray || true\n    pkill -f siirl.main_dag || true\n    pkill -f torchrun || true\n    pkill -f vllm || true\n    pkill -f hccl || true\n    for port in ${MASTER_PORT:-29500} ${RAY_MASTER_PORT:-6379}; do\n        for pid in $(lsof -ti :$port); do\n            kill -9 $pid || true\n        done\n    done\n    sleep 3\n    echo \"Cleanup finished.\"\n\n    export GLOO_SOCKET_TIMEOUT=600\n    export GLOO_TCP_TIMEOUT=600\n    export GLOO_LOG_LEVEL=DEBUG\n    export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}\n    export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}\n    export RAY_MASTER_ADDR=$MASTER_ADDR\n    \n    start_ray_cluster\n\n    if [ \"$NNODES\" -gt 1 ] && [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"Waiting for all $NNODES nodes to join...\"\n        local TIMEOUT=600; local start_time=$(date +%s)\n        while true; do\n            if [ \"$(( $(date +%s) - start_time ))\" -ge \"$TIMEOUT\" ]; then echo \"Error: Timeout waiting for nodes.\" >&2; exit 1; fi\n            local ready_nodes=$(ray list nodes --format=json | python3 -c \"import sys, json; print(len(json.load(sys.stdin)))\")\n            if [ \"$ready_nodes\" -ge \"$NNODES\" ]; then break; fi\n            echo \"Waiting... ($ready_nodes / $NNODES nodes ready)\"\n            sleep 5\n        done\n        echo \"All $NNODES nodes have joined.\"\n    fi\n\n    if [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"INFO [RANK 0]: Starting main training command.\"\n        eval \"${TRAINING_CMD[@]}\" \"$@\"\n        echo \"INFO [RANK 0]: Training finished.\"\n        sleep 30; ray stop --force >/dev/null 2>&1\n    elif [ \"$NNODES\" -gt 1 ]; then\n        local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n        echo \"INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address.\"\n        while ray health-check --address \"$head_node_address\" &>/dev/null; do sleep 15; done\n        echo \"INFO [RANK $NODE_RANK]: Head node down. Exiting.\"\n    fi\n\n    echo \"INFO: Script finished on rank $NODE_RANK.\"\n}\n\n# --- Script Entrypoint ---\nmain \"$@\"\n"
  },
  {
    "path": "examples/grpo_trainer/run_qwen2_5-72b-npu.sh",
    "content": " #!/usr/bin/env bash\n# ===================================================================================\n# ===                       USER CONFIGURATION SECTION                            ===\n# ===================================================================================\n\n# --- Experiment and Model Definition ---\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh\nsource /usr/local/Ascend/nnal/atb/set_env.sh\nexport LD_LIBRARY_PATH=/usr/local/Ascend/driver/:$LD_LIBRARY_PATH\nexport LD_LIBRARY_PATH=/usr/local/Ascend/driver/lib64/driver/:$LD_LIBRARY_PATH\n\nexport DATASET=deepscaler\nexport ALG=grpo\nexport MODEL_NAME=qwen2.5-72b\nexport VLLM_USE_V1=1\n\n# --- Path Definitions ---\nexport HOME={your_home_path}\nexport TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet\nexport TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet\nexport MODEL_PATH=$HOME/data/models/Qwen2.5-32B-Instruct\n\n# Base output paths\nexport BASE_CKPT_PATH=ckpts\nexport BASE_TENSORBOARD_PATH=tensorboard\n\n# --- GLOO Configuration ---\nexport GLOO_SOCKET_IFNAME=enp91s0np0\nexport HCCL_SOCKET_IFNAME=enp91s0np0\nexport GLOO_SOCKET_TIMEOUT=600\nexport GLOO_TCP_TIMEOUT=600\nexport HCCL_CONNECT_TIMEOUT=7200\nexport GLOO_LOG_LEVEL=INFO\n\n# --- Key Training Hyperparameters ---\nexport TRAIN_BATCH_SIZE_PER_NODE=512\nexport PPO_MINI_BATCH_SIZE_PER_NODE=32\nexport PPO_MICRO_BATCH_SIZE_PER_GPU=2\nexport MAX_PROMPT_LENGTH=2048\nexport MAX_RESPONSE_LENGTH=2048\nexport ROLLOUT_GPU_MEMORY_UTILIZATION=0.5\nexport ROLLOUT_TP=8\nexport ROLLOUT_N=6\nexport SAVE_FREQ=-1\nexport TEST_FREQ=5\nexport TOTAL_EPOCHS=300\nexport MAX_CKPT_KEEP=5\n\n# --- Distributed Training & Infrastructure ---\nexport N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-16}\nexport NNODES=${PET_NNODES:-1}\nexport NODE_RANK=${PET_NODE_RANK:-0}\nexport MASTER_ADDR=${MASTER_ADDR:-localhost}\n\n# --- Output Paths and Experiment Naming ---\nexport CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes\nexport PROJECT_NAME=yangdian_npu_scale_up\nexport EXPERIMENT_NAME=npu_siirl_${MODEL_NAME}_${NNODES}_nodes_${ALG}_${DATASET}_experiment_$(date +%Y%m%d_%H%M%S)\nexport TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp\nexport SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp\n\n# --- Calculated Global Hyperparameters ---\nexport TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))\nexport PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))\n\n# --- Define the Training Command and its Arguments ---\nTRAINING_CMD=(\n    python3 -m siirl.main_dag\n    algorithm.adv_estimator=\\$ALG\n    data.train_files=\\$TRAIN_DATA_PATH\n    data.val_files=\\$TEST_DATA_PATH\n    data.train_batch_size=\\$TRAIN_BATCH_SIZE\n    data.max_prompt_length=\\$MAX_PROMPT_LENGTH\n    data.max_response_length=\\$MAX_RESPONSE_LENGTH\n    data.filter_overlong_prompts=True\n    data.truncation='error'\n    data.auto_repeat=True\n    actor_rollout_ref.model.path=\\$MODEL_PATH\n    actor_rollout_ref.actor.optim.lr=1e-6\n    actor_rollout_ref.model.use_remove_padding=True\n    actor_rollout_ref.actor.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.actor.use_kl_loss=True\n    actor_rollout_ref.actor.grad_clip=0.3 \n    actor_rollout_ref.actor.clip_ratio=0.2\n    actor_rollout_ref.actor.entropy_coeff=0\n    
actor_rollout_ref.actor.kl_loss_coef=0.01\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl\n    actor_rollout_ref.model.enable_gradient_checkpointing=True\n    actor_rollout_ref.actor.fsdp_config.param_offload=True\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=64\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.rollout.tensor_model_parallel_size=\\$ROLLOUT_TP\n    actor_rollout_ref.rollout.name=vllm\n    actor_rollout_ref.rollout.gpu_memory_utilization=\\$ROLLOUT_GPU_MEMORY_UTILIZATION\n    actor_rollout_ref.rollout.n=\\$ROLLOUT_N\n    actor_rollout_ref.rollout.prompt_length=\\$MAX_PROMPT_LENGTH  \n    actor_rollout_ref.rollout.response_length=\\$MAX_RESPONSE_LENGTH\n    actor_rollout_ref.rollout.enable_chunked_prefill=True\n    actor_rollout_ref.rollout.enforce_eager=True\n    actor_rollout_ref.rollout.free_cache_engine=True\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.ref.fsdp_config.param_offload=True\n    algorithm.use_kl_in_reward=False\n    trainer.critic_warmup=0\n    trainer.logger=['console','tensorboard']\n    trainer.project_name=\\$PROJECT_NAME\n    trainer.experiment_name=\\$EXPERIMENT_NAME\n    trainer.n_gpus_per_node=\\$N_GPUS_PER_NODE\n    trainer.nnodes=\\$NNODES\n    trainer.save_freq=\\$SAVE_FREQ\n    trainer.test_freq=\\$TEST_FREQ\n    trainer.total_epochs=\\$TOTAL_EPOCHS\n    trainer.resume_mode=auto\n    trainer.max_actor_ckpt_to_keep=\\$MAX_CKPT_KEEP\n    trainer.default_local_dir=\\$CKPT_PATH\n    trainer.val_before_train=True\n    trainer.device=npu\n)\n\n# ===================================================================================\n# ===                  MAIN EXECUTION LOGIC & INFRASTRUCTURE                      ===\n# ===================================================================================\n\n# --- Boilerplate Setup ---\nset -e\nset -o pipefail\nset -x\n\n# --- Infrastructure & Boilerplate Functions ---\nstart_ray_cluster() {\n    local RAY_HEAD_WAIT_TIMEOUT=600\n    export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200\n    export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120\n\n    local ray_start_common_opts=(\n        --num-gpus \"$N_GPUS_PER_NODE\"\n        --object-store-memory 100000000000\n        --memory 100000000000\n    )\n\n    if [ \"$NNODES\" -gt 1 ]; then\n        if [ \"$NODE_RANK\" = \"0\" ]; then\n            echo \"INFO: Starting Ray head node on $(hostname)...\"\n            export RAY_ADDRESS=\"$RAY_MASTER_ADDR:$RAY_MASTER_PORT\"\n            ray start --head --port=\"$RAY_MASTER_PORT\" --dashboard-port=\"$RAY_DASHBOARD_PORT\" \"${ray_start_common_opts[@]}\" --system-config='{\"gcs_server_request_timeout_seconds\": 60, \"gcs_rpc_server_reconnect_timeout_s\": 60}'\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$RAY_ADDRESS\" ; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head node. Exiting.\" >&2; ray stop --force; exit 1; fi\n                echo \"Head node not healthy yet. 
Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head node is healthy.\"\n        else\n            local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n            echo \"INFO: Worker node $(hostname) waiting for head at $head_node_address...\"\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$head_node_address\" ; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head. Exiting.\" >&2; exit 1; fi\n                echo \"Head not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head is healthy. Worker starting...\"\n            ray start --address=\"$head_node_address\" \"${ray_start_common_opts[@]}\"\n        fi\n    else\n        echo \"INFO: Starting Ray in single-node mode...\"\n        ray start --head \"${ray_start_common_opts[@]}\"\n    fi\n}\n\n# --- Main Execution Function ---\nmain() {\n    local timestamp=$(date +\"%Y%m%d_%H%M%S\")\n\n    ray stop --force\n    echo \"Cleaning up residual distributed processes...\"\n    pkill -f ray || true\n    pkill -f siirl.main_dag || true\n    pkill -f torchrun || true\n    pkill -f vllm || true\n    pkill -f hccl || true\n    for port in ${MASTER_PORT:-29500} ${RAY_MASTER_PORT:-6379}; do\n        for pid in $(lsof -ti :$port); do\n            kill -9 $pid || true\n        done\n    done\n    sleep 3\n    echo \"Cleanup finished.\"\n\n    export GLOO_SOCKET_TIMEOUT=600\n    export GLOO_TCP_TIMEOUT=600\n    export GLOO_LOG_LEVEL=DEBUG\n    export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}\n    export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}\n    export RAY_MASTER_ADDR=$MASTER_ADDR\n    \n    start_ray_cluster\n\n    if [ \"$NNODES\" -gt 1 ] && [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"Waiting for all $NNODES nodes to join...\"\n        local TIMEOUT=600; local start_time=$(date +%s)\n        while true; do\n            if [ \"$(( $(date +%s) - start_time ))\" -ge \"$TIMEOUT\" ]; then echo \"Error: Timeout waiting for nodes.\" >&2; exit 1; fi\n            local ready_nodes=$(ray list nodes --format=json | python3 -c \"import sys, json; print(len(json.load(sys.stdin)))\")\n            if [ \"$ready_nodes\" -ge \"$NNODES\" ]; then break; fi\n            echo \"Waiting... ($ready_nodes / $NNODES nodes ready)\"\n            sleep 5\n        done\n        echo \"All $NNODES nodes have joined.\"\n    fi\n\n    if [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"INFO [RANK 0]: Starting main training command.\"\n        eval \"${TRAINING_CMD[@]}\" \"$@\"\n        echo \"INFO [RANK 0]: Training finished.\"\n        sleep 30; ray stop --force >/dev/null 2>&1\n    elif [ \"$NNODES\" -gt 1 ]; then\n        local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n        echo \"INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address.\"\n        while ray health-check --address \"$head_node_address\" &>/dev/null; do sleep 15; done\n        echo \"INFO [RANK $NODE_RANK]: Head node down. Exiting.\"\n    fi\n\n    echo \"INFO: Script finished on rank $NODE_RANK.\"\n}\n\n# --- Script Entrypoint ---\nmain \"$@\"\n"
  },
  {
    "path": "examples/grpo_trainer/run_qwen2_5-7b-npu-e2e_prof.sh",
    "content": " #!/usr/bin/env bash\n# ===================================================================================\n# ===                       USER CONFIGURATION SECTION                            ===\n# ===================================================================================\n\n# --- Experiment and Model Definition ---\nexport DATASET=gsm8k\nexport ALG=grpo\nexport MODEL_NAME=qwen2.5-7b\n\n# --- Path Definitions ---\nexport HOME={your_home_path}\nexport TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet\nexport TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet\nexport MODEL_PATH=$HOME/data/models/Qwen2.5-7B-Instruct\nexport PROFILE_PATH='./profile_data'\n\n# Base output paths\nexport BASE_CKPT_PATH=ckpts\nexport BASE_TENSORBOARD_PATH=tensorboard\n\n# --- Key Training Hyperparameters ---\nexport TRAIN_BATCH_SIZE_PER_NODE=16\nexport PPO_MINI_BATCH_SIZE_PER_NODE=16\nexport PPO_MICRO_BATCH_SIZE_PER_GPU=1\nexport MAX_PROMPT_LENGTH=1024\nexport MAX_RESPONSE_LENGTH=128\nexport ROLLOUT_GPU_MEMORY_UTILIZATION=0.5\nexport ROLLOUT_TP=2\nexport ROLLOUT_N=5\nexport SAVE_FREQ=30\nexport TEST_FREQ=10\nexport TOTAL_EPOCHS=30\nexport MAX_CKPT_KEEP=5\n\n# --- Distributed Training & Infrastructure ---\nexport N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}\nexport NNODES=${PET_NNODES:-1}\nexport NODE_RANK=${PET_NODE_RANK:-0}\nexport MASTER_ADDR=${MASTER_ADDR:-localhost}\n\n# --- Output Paths and Experiment Naming ---\nexport CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes\nexport PROJECT_NAME=siirl_${DATASET}_${ALG}\nexport EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment\nexport TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp\nexport SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp\n\n# --- Calculated Global Hyperparameters ---\nexport TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))\nexport PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))\n\n# --- Define the Training Command and its Arguments ---\nTRAINING_CMD=(\n    python3 -m siirl.main_dag\n    algorithm.adv_estimator=\\$ALG\n    data.train_files=\\$TRAIN_DATA_PATH\n    data.val_files=\\$TEST_DATA_PATH\n    data.train_batch_size=\\$TRAIN_BATCH_SIZE\n    data.max_prompt_length=\\$MAX_PROMPT_LENGTH\n    data.max_response_length=\\$MAX_RESPONSE_LENGTH\n    data.filter_overlong_prompts=True\n    data.truncation='error'\n    actor_rollout_ref.model.path=\\$MODEL_PATH\n    actor_rollout_ref.actor.optim.lr=5e-8\n    actor_rollout_ref.model.use_remove_padding=False\n    actor_rollout_ref.actor.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.actor.use_kl_loss=True\n    actor_rollout_ref.actor.entropy_coeff=0\n    actor_rollout_ref.actor.kl_loss_coef=0.001\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl\n    actor_rollout_ref.model.enable_gradient_checkpointing=True\n    actor_rollout_ref.actor.fsdp_config.param_offload=False\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.rollout.tensor_model_parallel_size=\\$ROLLOUT_TP\n    actor_rollout_ref.rollout.name=vllm\n    actor_rollout_ref.rollout.gpu_memory_utilization=\\$ROLLOUT_GPU_MEMORY_UTILIZATION\n    actor_rollout_ref.rollout.n=\\$ROLLOUT_N\n    
actor_rollout_ref.rollout.prompt_length=\\$MAX_PROMPT_LENGTH  \n    actor_rollout_ref.rollout.response_length=\\$MAX_RESPONSE_LENGTH\n    actor_rollout_ref.rollout.enable_chunked_prefill=False\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.ref.fsdp_config.param_offload=True\n    algorithm.use_kl_in_reward=False\n    trainer.critic_warmup=0\n    trainer.logger=['console','tensorboard']\n    trainer.project_name=\\$PROJECT_NAME\n    trainer.experiment_name=\\$EXPERIMENT_NAME\n    trainer.n_gpus_per_node=\\$N_GPUS_PER_NODE\n    trainer.nnodes=\\$NNODES\n    trainer.save_freq=\\$SAVE_FREQ\n    trainer.test_freq=\\$TEST_FREQ\n    trainer.total_epochs=\\$TOTAL_EPOCHS\n    trainer.resume_mode=auto\n    trainer.max_actor_ckpt_to_keep=\\$MAX_CKPT_KEEP\n    trainer.default_local_dir=\\$CKPT_PATH\n    trainer.val_before_train=True\n    trainer.device=npu\n    profiler.enable=True\n    profiler.save_path=\\$PROFILE_PATH\n    profiler.level='level1'\n    profiler.ranks=[0]\n    profiler.profile_steps=[3]\n    profiler.discrete=False\n    profiler.with_cpu=True\n    profiler.with_memory=True\n)\n\n# ===================================================================================\n# ===                  MAIN EXECUTION LOGIC & INFRASTRUCTURE                      ===\n# ===================================================================================\n\n# --- Boilerplate Setup ---\nset -e\nset -o pipefail\nset -x\n\n# --- Infrastructure & Boilerplate Functions ---\nstart_ray_cluster() {\n    local RAY_HEAD_WAIT_TIMEOUT=600\n    export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200\n    export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120\n\n    local ray_start_common_opts=(\n        --num-gpus \"$N_GPUS_PER_NODE\"\n        --object-store-memory 100000000000\n        --memory 100000000000\n    )\n\n    if [ \"$NNODES\" -gt 1 ]; then\n        if [ \"$NODE_RANK\" = \"0\" ]; then\n            echo \"INFO: Starting Ray head node on $(hostname)...\"\n            export RAY_ADDRESS=\"$RAY_MASTER_ADDR:$RAY_MASTER_PORT\"\n            ray start --head --port=\"$RAY_MASTER_PORT\" --dashboard-port=\"$RAY_DASHBOARD_PORT\" \"${ray_start_common_opts[@]}\" --system-config='{\"gcs_server_request_timeout_seconds\": 60, \"gcs_rpc_server_reconnect_timeout_s\": 60}'\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$RAY_ADDRESS\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head node. Exiting.\" >&2; ray stop --force; exit 1; fi\n                echo \"Head node not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head node is healthy.\"\n        else\n            local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n            echo \"INFO: Worker node $(hostname) waiting for head at $head_node_address...\"\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$head_node_address\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head. Exiting.\" >&2; exit 1; fi\n                echo \"Head not healthy yet. 
Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head is healthy. Worker starting...\"\n            ray start --address=\"$head_node_address\" \"${ray_start_common_opts[@]}\"\n        fi\n    else\n        echo \"INFO: Starting Ray in single-node mode...\"\n        ray start --head \"${ray_start_common_opts[@]}\"\n    fi\n}\n\n# --- Main Execution Function ---\nmain() {\n    local timestamp=$(date +\"%Y%m%d_%H%M%S\")\n    ray stop --force\n\n    \n\n    export GLOO_SOCKET_TIMEOUT=600\n    export GLOO_TCP_TIMEOUT=600\n    export GLOO_LOG_LEVEL=DEBUG\n    export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}\n    export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}\n    export RAY_MASTER_ADDR=$MASTER_ADDR\n    \n    start_ray_cluster\n\n    if [ \"$NNODES\" -gt 1 ] && [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"Waiting for all $NNODES nodes to join...\"\n        local TIMEOUT=600; local start_time=$(date +%s)\n        while true; do\n            if [ \"$(( $(date +%s) - start_time ))\" -ge \"$TIMEOUT\" ]; then echo \"Error: Timeout waiting for nodes.\" >&2; exit 1; fi\n            local ready_nodes=$(ray list nodes --format=json | python3 -c \"import sys, json; print(len(json.load(sys.stdin)))\")\n            if [ \"$ready_nodes\" -ge \"$NNODES\" ]; then break; fi\n            echo \"Waiting... ($ready_nodes / $NNODES nodes ready)\"\n            sleep 5\n        done\n        echo \"All $NNODES nodes have joined.\"\n    fi\n\n    if [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"INFO [RANK 0]: Starting main training command.\"\n        eval \"${TRAINING_CMD[@]}\" \"$@\"\n        echo \"INFO [RANK 0]: Training finished.\"\n        sleep 30; ray stop --force >/dev/null 2>&1\n    elif [ \"$NNODES\" -gt 1 ]; then\n        local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n        echo \"INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address.\"\n        while ray health-check --address \"$head_node_address\" &>/dev/null; do sleep 15; done\n        echo \"INFO [RANK $NODE_RANK]: Head node down. Exiting.\"\n    fi\n\n    echo \"INFO: Script finished on rank $NODE_RANK.\"\n}\n\n# --- Script Entrypoint ---\nmain \"$@\"\n"
  },
  {
    "path": "examples/grpo_trainer/run_qwen2_5-7b-npu-mindspeed.sh",
    "content": " #!/usr/bin/env bash\n# ===================================================================================\n# ===                       USER CONFIGURATION SECTION                            ===\n# ===================================================================================\n\n# --- Experiment and Model Definition ---\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh\nsource /usr/local/Ascend/nnal/atb/set_env.sh\nexport LD_LIBRARY_PATH=/usr/local/Ascend/driver/:$LD_LIBRARY_PATH\nexport LD_LIBRARY_PATH=/usr/local/Ascend/driver/lib64/driver/:$LD_LIBRARY_PATH\n\nexport DATASET=deepscaler\nexport ALG=grpo\nexport MODEL_NAME=qwen2.5-7b\nexport VLLM_USE_V1=1\n\n# --- Path Definitions ---\nexport HOME={your_home_path}\nexport TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet\nexport TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet\nexport MODEL_PATH=$HOME/data/models/Qwen2.5-7B-Instruct\n\n# Base output paths\nexport BASE_CKPT_PATH=ckpts\nexport BASE_TENSORBOARD_PATH=tensorboard\n\n# --- GLOO Configuration ---\nexport GLOO_SOCKET_TIMEOUT=600\nexport GLOO_TCP_TIMEOUT=600\nexport HCCL_CONNECT_TIMEOUT=7200\nexport GLOO_LOG_LEVEL=INFO\n\n# --- Key Training Hyperparameters ---\nexport TRAIN_BATCH_SIZE_PER_NODE=1024\nexport PPO_MINI_BATCH_SIZE_PER_NODE=256\nexport PPO_MICRO_BATCH_SIZE_PER_GPU=4\nexport MAX_PROMPT_LENGTH=2048\nexport MAX_RESPONSE_LENGTH=2048\nexport ROLLOUT_GPU_MEMORY_UTILIZATION=0.5\nexport ROLLOUT_TP=4\nexport ROLLOUT_N=5\nexport ACTOR_REF_TP=4\nexport ACTOR_REF_PP=1\nexport ACTOR_REF_CP=1\nexport ACTOR_REF_SP=False\n\nexport SAVE_FREQ=-1\nexport TEST_FREQ=5\nexport TOTAL_EPOCHS=300\nexport MAX_CKPT_KEEP=5\n\n# --- Distributed Training & Infrastructure ---\nexport N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}\nexport NNODES=${PET_NNODES:-1}\nexport NODE_RANK=${PET_NODE_RANK:-0}\nexport MASTER_ADDR=${MASTER_ADDR:-localhost}\n\n# --- Output Paths and Experiment Naming ---\nexport CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes\nexport PROJECT_NAME=siirl_${DATASET}_${ALG}\nexport EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment\nexport TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp\nexport SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp\n\n# --- Calculated Global Hyperparameters ---\nexport TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))\nexport PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))\n\n# --- Define the Training Command and its Arguments ---\nTRAINING_CMD=(\n    python3 -m siirl.main_dag\n    algorithm.adv_estimator=\\$ALG\n    data.train_files=\\$TRAIN_DATA_PATH\n    data.val_files=\\$TEST_DATA_PATH\n    data.train_batch_size=\\$TRAIN_BATCH_SIZE\n    data.max_prompt_length=\\$MAX_PROMPT_LENGTH\n    data.max_response_length=\\$MAX_RESPONSE_LENGTH\n    data.filter_overlong_prompts=True\n    data.truncation='error'\n    data.auto_repeat=True\n    actor_rollout_ref.model.path=\\$MODEL_PATH\n    actor_rollout_ref.actor.optim.lr=1e-6\n    actor_rollout_ref.model.use_remove_padding=True\n    actor_rollout_ref.actor.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.actor.use_kl_loss=True\n    actor_rollout_ref.actor.grad_clip=0.3 \n    actor_rollout_ref.actor.clip_ratio=0.2\n    actor_rollout_ref.actor.entropy_coeff=0\n    actor_rollout_ref.actor.kl_loss_coef=0.01\n    
actor_rollout_ref.actor.kl_loss_type=low_var_kl\n    actor_rollout_ref.actor.strategy=megatron\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=\\$ACTOR_REF_TP\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=\\$ACTOR_REF_PP\n    actor_rollout_ref.actor.megatron.context_parallel_size=\\$ACTOR_REF_CP\n    actor_rollout_ref.actor.megatron.sequence_parallel=\\$ACTOR_REF_SP\n    actor_rollout_ref.actor.megatron.use_distributed_optimizer=True\n    actor_rollout_ref.actor.megatron.param_dtype=bfloat16\n    actor_rollout_ref.actor.megatron.param_offload=True\n    actor_rollout_ref.actor.megatron.use_dist_checkpointing=False\n    actor_rollout_ref.actor.megatron.use_mbridge=False\n    +actor_rollout_ref.actor.megatron.override_transformer_config.use_flash_attn=True\n    +actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.model.enable_gradient_checkpointing=True\n    actor_rollout_ref.actor.fsdp_config.param_offload=True\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=16\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.rollout.tensor_model_parallel_size=\\$ROLLOUT_TP\n    actor_rollout_ref.rollout.name=vllm\n    actor_rollout_ref.rollout.gpu_memory_utilization=\\$ROLLOUT_GPU_MEMORY_UTILIZATION\n    actor_rollout_ref.rollout.n=\\$ROLLOUT_N\n    actor_rollout_ref.rollout.prompt_length=\\$MAX_PROMPT_LENGTH  \n    actor_rollout_ref.rollout.response_length=\\$MAX_RESPONSE_LENGTH\n    actor_rollout_ref.rollout.enable_chunked_prefill=True\n    actor_rollout_ref.rollout.enforce_eager=True\n    actor_rollout_ref.rollout.free_cache_engine=True\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.ref.fsdp_config.param_offload=True\n    algorithm.use_kl_in_reward=False\n    trainer.critic_warmup=0\n    trainer.logger=['console','tensorboard']\n    trainer.project_name=\\$PROJECT_NAME\n    trainer.experiment_name=\\$EXPERIMENT_NAME\n    trainer.n_gpus_per_node=\\$N_GPUS_PER_NODE\n    trainer.nnodes=\\$NNODES\n    trainer.save_freq=\\$SAVE_FREQ\n    trainer.test_freq=\\$TEST_FREQ\n    trainer.total_epochs=\\$TOTAL_EPOCHS\n    trainer.resume_mode=auto\n    trainer.max_actor_ckpt_to_keep=\\$MAX_CKPT_KEEP\n    trainer.default_local_dir=\\$CKPT_PATH\n    trainer.val_before_train=True\n    trainer.device=npu\n)\n\n# ===================================================================================\n# ===                  MAIN EXECUTION LOGIC & INFRASTRUCTURE                      ===\n# ===================================================================================\n\n# --- Boilerplate Setup ---\nset -e\nset -o pipefail\nset -x\n\n# --- Infrastructure & Boilerplate Functions ---\nstart_ray_cluster() {\n    local RAY_HEAD_WAIT_TIMEOUT=600\n    export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200\n    export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120\n\n    local ray_start_common_opts=(\n        --num-gpus \"$N_GPUS_PER_NODE\"\n        --object-store-memory 100000000000\n        --memory 100000000000\n    )\n\n    if [ \"$NNODES\" -gt 1 ]; then\n        if [ \"$NODE_RANK\" = \"0\" ]; then\n            echo 
\"INFO: Starting Ray head node on $(hostname)...\"\n            export RAY_ADDRESS=\"$RAY_MASTER_ADDR:$RAY_MASTER_PORT\"\n            ray start --head --port=\"$RAY_MASTER_PORT\" --dashboard-port=\"$RAY_DASHBOARD_PORT\" \"${ray_start_common_opts[@]}\" --system-config='{\"gcs_server_request_timeout_seconds\": 60, \"gcs_rpc_server_reconnect_timeout_s\": 60}'\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$RAY_ADDRESS\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head node. Exiting.\" >&2; ray stop --force; exit 1; fi\n                echo \"Head node not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head node is healthy.\"\n        else\n            local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n            echo \"INFO: Worker node $(hostname) waiting for head at $head_node_address...\"\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$head_node_address\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head. Exiting.\" >&2; exit 1; fi\n                echo \"Head not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head is healthy. Worker starting...\"\n            ray start --address=\"$head_node_address\" \"${ray_start_common_opts[@]}\"\n        fi\n    else\n        echo \"INFO: Starting Ray in single-node mode...\"\n        ray start --head \"${ray_start_common_opts[@]}\"\n    fi\n}\n\n# --- Main Execution Function ---\nmain() {\n    local timestamp=$(date +\"%Y%m%d_%H%M%S\")\n    ray stop --force\n    echo \"Cleaning up residual distributed processes...\"\n    pkill -f ray || true\n    pkill -f siirl.main_dag || true\n    pkill -f torchrun || true\n    pkill -f vllm || true\n    pkill -f hccl || true\n    for port in ${MASTER_PORT:-29500} ${RAY_MASTER_PORT:-6379}; do\n        for pid in $(lsof -ti :$port); do\n            kill -9 $pid || true\n        done\n    done\n    sleep 3\n    echo \"Cleanup finished.\"\n\n    export GLOO_SOCKET_TIMEOUT=600\n    export GLOO_TCP_TIMEOUT=600\n    export GLOO_LOG_LEVEL=DEBUG\n    export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}\n    export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}\n    export RAY_MASTER_ADDR=$MASTER_ADDR\n    \n    start_ray_cluster\n\n    if [ \"$NNODES\" -gt 1 ] && [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"Waiting for all $NNODES nodes to join...\"\n        local TIMEOUT=600; local start_time=$(date +%s)\n        while true; do\n            if [ \"$(( $(date +%s) - start_time ))\" -ge \"$TIMEOUT\" ]; then echo \"Error: Timeout waiting for nodes.\" >&2; exit 1; fi\n            local ready_nodes=$(ray list nodes --format=json | python3 -c \"import sys, json; print(len(json.load(sys.stdin)))\")\n            if [ \"$ready_nodes\" -ge \"$NNODES\" ]; then break; fi\n            echo \"Waiting... 
($ready_nodes / $NNODES nodes ready)\"\n            sleep 5\n        done\n        echo \"All $NNODES nodes have joined.\"\n    fi\n\n    if [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"INFO [RANK 0]: Starting main training command.\"\n        eval \"${TRAINING_CMD[@]}\" \"$@\"\n        echo \"INFO [RANK 0]: Training finished.\"\n        sleep 30; ray stop --force >/dev/null 2>&1\n    elif [ \"$NNODES\" -gt 1 ]; then\n        local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n        echo \"INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address.\"\n        while ray health-check --address \"$head_node_address\" &>/dev/null; do sleep 15; done\n        echo \"INFO [RANK $NODE_RANK]: Head node down. Exiting.\"\n    fi\n\n    echo \"INFO: Script finished on rank $NODE_RANK.\"\n}\n\n# --- Script Entrypoint ---\nmain \"$@\"\n"
  },
  {
    "path": "examples/grpo_trainer/run_qwen2_5-7b-npu.sh",
    "content": " #!/usr/bin/env bash\n# ===================================================================================\n# ===                       USER CONFIGURATION SECTION                            ===\n# ===================================================================================\n\n# --- Experiment and Model Definition ---\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh\nsource /usr/local/Ascend/nnal/atb/set_env.sh\nexport LD_LIBRARY_PATH=/usr/local/Ascend/driver/:$LD_LIBRARY_PATH\nexport LD_LIBRARY_PATH=/usr/local/Ascend/driver/lib64/driver/:$LD_LIBRARY_PATH\n\nexport DATASET=deepscaler\nexport ALG=grpo\nexport MODEL_NAME=qwen2.5-7b\nexport VLLM_USE_V1=1\n\n# --- Path Definitions ---\nexport HOME={your_home_path}\nexport TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet\nexport TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet\nexport MODEL_PATH=$HOME/data/models/Qwen2.5-7B-Instruct\n\n# Base output paths\nexport BASE_CKPT_PATH=ckpts\nexport BASE_TENSORBOARD_PATH=tensorboard\n\n# --- GLOO Configuration ---\nexport GLOO_SOCKET_IFNAME=enp91s0np0\nexport HCCL_SOCKET_IFNAME=enp91s0np0\nexport GLOO_SOCKET_TIMEOUT=600\nexport GLOO_TCP_TIMEOUT=600\nexport HCCL_CONNECT_TIMEOUT=7200\nexport GLOO_LOG_LEVEL=INFO\n\n# --- Key Training Hyperparameters ---\nexport TRAIN_BATCH_SIZE_PER_NODE=1024\nexport PPO_MINI_BATCH_SIZE_PER_NODE=256\nexport PPO_MICRO_BATCH_SIZE_PER_GPU=4\nexport MAX_PROMPT_LENGTH=2048\nexport MAX_RESPONSE_LENGTH=2048\nexport ROLLOUT_GPU_MEMORY_UTILIZATION=0.5\nexport ROLLOUT_TP=4\nexport ROLLOUT_N=5\nexport SAVE_FREQ=-1\nexport TEST_FREQ=5\nexport TOTAL_EPOCHS=300\nexport MAX_CKPT_KEEP=5\n\n# --- Distributed Training & Infrastructure ---\nexport N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-16}\nexport NNODES=${PET_NNODES:-1}\nexport NODE_RANK=${PET_NODE_RANK:-0}\nexport MASTER_ADDR=${MASTER_ADDR:-localhost}\n\n# --- Output Paths and Experiment Naming ---\nexport CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes\nexport PROJECT_NAME=siirl_${DATASET}_${ALG}\nexport EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment\nexport TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp\nexport SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp\n\n# --- Calculated Global Hyperparameters ---\nexport TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))\nexport PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))\n\n# --- Define the Training Command and its Arguments ---\nTRAINING_CMD=(\n    python3 -m siirl.main_dag\n    algorithm.adv_estimator=\\$ALG\n    data.train_files=\\$TRAIN_DATA_PATH\n    data.val_files=\\$TEST_DATA_PATH\n    data.train_batch_size=\\$TRAIN_BATCH_SIZE\n    data.max_prompt_length=\\$MAX_PROMPT_LENGTH\n    data.max_response_length=\\$MAX_RESPONSE_LENGTH\n    data.filter_overlong_prompts=True\n    data.truncation='error'\n    data.auto_repeat=True\n    actor_rollout_ref.model.path=\\$MODEL_PATH\n    actor_rollout_ref.actor.optim.lr=1e-6\n    actor_rollout_ref.model.use_remove_padding=True\n    actor_rollout_ref.actor.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.actor.use_kl_loss=True\n    actor_rollout_ref.actor.grad_clip=0.3 \n    actor_rollout_ref.actor.clip_ratio=0.2\n    actor_rollout_ref.actor.entropy_coeff=0\n    actor_rollout_ref.actor.kl_loss_coef=0.01\n    
actor_rollout_ref.actor.kl_loss_type=low_var_kl\n    actor_rollout_ref.model.enable_gradient_checkpointing=True\n    actor_rollout_ref.actor.fsdp_config.param_offload=True\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True\n    actor_rollout_ref.actor.fsdp_config.fsdp_size=16\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.rollout.tensor_model_parallel_size=\\$ROLLOUT_TP\n    actor_rollout_ref.rollout.name=vllm\n    actor_rollout_ref.rollout.gpu_memory_utilization=\\$ROLLOUT_GPU_MEMORY_UTILIZATION\n    actor_rollout_ref.rollout.n=\\$ROLLOUT_N\n    actor_rollout_ref.rollout.prompt_length=\\$MAX_PROMPT_LENGTH  \n    actor_rollout_ref.rollout.response_length=\\$MAX_RESPONSE_LENGTH\n    actor_rollout_ref.rollout.enable_chunked_prefill=True\n    actor_rollout_ref.rollout.enforce_eager=True\n    actor_rollout_ref.rollout.free_cache_engine=True\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.ref.fsdp_config.param_offload=True\n    algorithm.use_kl_in_reward=False\n    trainer.critic_warmup=0\n    trainer.logger=['console','tensorboard']\n    trainer.project_name=\\$PROJECT_NAME\n    trainer.experiment_name=\\$EXPERIMENT_NAME\n    trainer.n_gpus_per_node=\\$N_GPUS_PER_NODE\n    trainer.nnodes=\\$NNODES\n    trainer.save_freq=\\$SAVE_FREQ\n    trainer.test_freq=\\$TEST_FREQ\n    trainer.total_epochs=\\$TOTAL_EPOCHS\n    trainer.resume_mode=auto\n    trainer.max_actor_ckpt_to_keep=\\$MAX_CKPT_KEEP\n    trainer.default_local_dir=\\$CKPT_PATH\n    trainer.val_before_train=True\n    trainer.device=npu\n)\n\n# ===================================================================================\n# ===                  MAIN EXECUTION LOGIC & INFRASTRUCTURE                      ===\n# ===================================================================================\n\n# --- Boilerplate Setup ---\nset -e\nset -o pipefail\nset -x\n\n# --- Infrastructure & Boilerplate Functions ---\nstart_ray_cluster() {\n    local RAY_HEAD_WAIT_TIMEOUT=600\n    export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200\n    export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120\n\n    local ray_start_common_opts=(\n        --num-gpus \"$N_GPUS_PER_NODE\"\n        --object-store-memory 100000000000\n        --memory 100000000000\n    )\n\n    if [ \"$NNODES\" -gt 1 ]; then\n        if [ \"$NODE_RANK\" = \"0\" ]; then\n            echo \"INFO: Starting Ray head node on $(hostname)...\"\n            export RAY_ADDRESS=\"$RAY_MASTER_ADDR:$RAY_MASTER_PORT\"\n            ray start --head --port=\"$RAY_MASTER_PORT\" --dashboard-port=\"$RAY_DASHBOARD_PORT\" \"${ray_start_common_opts[@]}\" --system-config='{\"gcs_server_request_timeout_seconds\": 60, \"gcs_rpc_server_reconnect_timeout_s\": 60}'\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$RAY_ADDRESS\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head node. Exiting.\" >&2; ray stop --force; exit 1; fi\n                echo \"Head node not healthy yet. 
Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head node is healthy.\"\n        else\n            local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n            echo \"INFO: Worker node $(hostname) waiting for head at $head_node_address...\"\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$head_node_address\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head. Exiting.\" >&2; exit 1; fi\n                echo \"Head not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head is healthy. Worker starting...\"\n            ray start --address=\"$head_node_address\" \"${ray_start_common_opts[@]}\"\n        fi\n    else\n        echo \"INFO: Starting Ray in single-node mode...\"\n        ray start --head \"${ray_start_common_opts[@]}\"\n    fi\n}\n\n# --- Main Execution Function ---\nmain() {\n    local timestamp=$(date +\"%Y%m%d_%H%M%S\")\n    ray stop --force\n    echo \"Cleaning up residual distributed processes...\"\n    pkill -f ray || true\n    pkill -f siirl.main_dag || true\n    pkill -f torchrun || true\n    pkill -f vllm || true\n    pkill -f hccl || true\n    for port in ${MASTER_PORT:-29500} ${RAY_MASTER_PORT:-6379}; do\n        for pid in $(lsof -ti :$port); do\n            kill -9 $pid || true\n        done\n    done\n    sleep 3\n    echo \"Cleanup finished.\"\n\n    export GLOO_SOCKET_TIMEOUT=600\n    export GLOO_TCP_TIMEOUT=600\n    export GLOO_LOG_LEVEL=DEBUG\n    export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}\n    export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}\n    export RAY_MASTER_ADDR=$MASTER_ADDR\n    \n    start_ray_cluster\n\n    if [ \"$NNODES\" -gt 1 ] && [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"Waiting for all $NNODES nodes to join...\"\n        local TIMEOUT=600; local start_time=$(date +%s)\n        while true; do\n            if [ \"$(( $(date +%s) - start_time ))\" -ge \"$TIMEOUT\" ]; then echo \"Error: Timeout waiting for nodes.\" >&2; exit 1; fi\n            local ready_nodes=$(ray list nodes --format=json | python3 -c \"import sys, json; print(len(json.load(sys.stdin)))\")\n            if [ \"$ready_nodes\" -ge \"$NNODES\" ]; then break; fi\n            echo \"Waiting... ($ready_nodes / $NNODES nodes ready)\"\n            sleep 5\n        done\n        echo \"All $NNODES nodes have joined.\"\n    fi\n\n    if [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"INFO [RANK 0]: Starting main training command.\"\n        eval \"${TRAINING_CMD[@]}\" \"$@\"\n        echo \"INFO [RANK 0]: Training finished.\"\n        sleep 30; ray stop --force >/dev/null 2>&1\n    elif [ \"$NNODES\" -gt 1 ]; then\n        local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n        echo \"INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address.\"\n        while ray health-check --address \"$head_node_address\" &>/dev/null; do sleep 15; done\n        echo \"INFO [RANK $NODE_RANK]: Head node down. Exiting.\"\n    fi\n\n    echo \"INFO: Script finished on rank $NODE_RANK.\"\n}\n\n# --- Script Entrypoint ---\nmain \"$@\"\n"
  },
  {
    "path": "examples/grpo_trainer/run_qwen2_5-7b.sh",
    "content": " #!/usr/bin/env bash\n# ===================================================================================\n# ===                       USER CONFIGURATION SECTION                            ===\n# ===================================================================================\n\n# --- Experiment and Model Definition ---\nexport DATASET=deepscaler\nexport ALG=grpo\nexport MODEL_NAME=qwen2.5-7b\n\n# --- Path Definitions ---\nexport HOME={your_home_path}\nexport TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet\nexport TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet\nexport MODEL_PATH=$HOME/data/models/Qwen2.5-7B-Instruct\n\n# Base output paths\nexport BASE_CKPT_PATH=ckpts\nexport BASE_TENSORBOARD_PATH=tensorboard\n\n# --- Key Training Hyperparameters ---\nexport TRAIN_BATCH_SIZE_PER_NODE=512\nexport PPO_MINI_BATCH_SIZE_PER_NODE=256\nexport PPO_MICRO_BATCH_SIZE_PER_GPU=8\nexport MAX_PROMPT_LENGTH=2048\nexport MAX_RESPONSE_LENGTH=4096\nexport ROLLOUT_GPU_MEMORY_UTILIZATION=0.6\nexport ROLLOUT_TP=2\nexport ROLLOUT_N=8\nexport SAVE_FREQ=30\nexport TEST_FREQ=10\nexport TOTAL_EPOCHS=30\nexport MAX_CKPT_KEEP=5\n\n# --- Multi-node (Multi-machine) distributed training environments ---\n\n# Uncomment the following line and set the correct network interface if needed for distributed backend\n# export GLOO_SOCKET_IFNAME=bond0  # Modify as needed\n\n# --- Distributed Training & Infrastructure ---\nexport N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}\nexport NNODES=${PET_NNODES:-1}\nexport NODE_RANK=${PET_NODE_RANK:-0}\nexport MASTER_ADDR=${MASTER_ADDR:-localhost}\n\n# --- Output Paths and Experiment Naming ---\nexport CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes\nexport PROJECT_NAME=siirl_${DATASET}_${ALG}\nexport EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment\nexport TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp\nexport SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp\n\n# --- Calculated Global Hyperparameters ---\nexport TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))\nexport PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))\n\n# --- Define the Training Command and its Arguments ---\nTRAINING_CMD=(\n    python3 -m siirl.main_dag\n    algorithm.adv_estimator=\\$ALG\n    data.train_files=\\$TRAIN_DATA_PATH\n    data.val_files=\\$TEST_DATA_PATH\n    data.train_batch_size=\\$TRAIN_BATCH_SIZE\n    data.max_prompt_length=\\$MAX_PROMPT_LENGTH\n    data.max_response_length=\\$MAX_RESPONSE_LENGTH\n    data.filter_overlong_prompts=True\n    data.truncation='error'\n    data.shuffle=False\n    actor_rollout_ref.model.path=\\$MODEL_PATH\n    actor_rollout_ref.actor.optim.lr=1e-6\n    actor_rollout_ref.model.use_remove_padding=True\n    actor_rollout_ref.model.use_fused_kernels=False\n    actor_rollout_ref.actor.policy_drift_coeff=0.001\n    actor_rollout_ref.actor.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.actor.use_kl_loss=False\n    actor_rollout_ref.actor.grad_clip=0.5\n    actor_rollout_ref.actor.clip_ratio=0.2\n    actor_rollout_ref.actor.kl_loss_coef=0.01\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl\n    actor_rollout_ref.model.enable_gradient_checkpointing=True\n    actor_rollout_ref.actor.fsdp_config.param_offload=False\n    
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.rollout.tensor_model_parallel_size=\\$ROLLOUT_TP\n    actor_rollout_ref.rollout.name=vllm\n    actor_rollout_ref.rollout.gpu_memory_utilization=\\$ROLLOUT_GPU_MEMORY_UTILIZATION\n    actor_rollout_ref.rollout.max_model_len=8192\n    actor_rollout_ref.rollout.enable_chunked_prefill=False\n    actor_rollout_ref.rollout.enforce_eager=False\n    actor_rollout_ref.rollout.free_cache_engine=False\n    actor_rollout_ref.rollout.n=\\$ROLLOUT_N\n    actor_rollout_ref.rollout.prompt_length=\\$MAX_PROMPT_LENGTH  \n    actor_rollout_ref.rollout.response_length=\\$MAX_RESPONSE_LENGTH\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.ref.fsdp_config.param_offload=True\n    algorithm.weight_factor_in_cpgd='STD_weight'\n    algorithm.kl_ctrl.kl_coef=0.001\n    trainer.critic_warmup=0\n    trainer.logger=['console','tensorboard']\n    trainer.project_name=\\$PROJECT_NAME\n    trainer.experiment_name=\\$EXPERIMENT_NAME\n    trainer.n_gpus_per_node=\\$N_GPUS_PER_NODE\n    trainer.nnodes=\\$NNODES\n    trainer.save_freq=\\$SAVE_FREQ\n    trainer.test_freq=\\$TEST_FREQ\n    trainer.total_epochs=\\$TOTAL_EPOCHS\n    trainer.resume_mode=auto\n    trainer.max_actor_ckpt_to_keep=\\$MAX_CKPT_KEEP\n    trainer.default_local_dir=\\$CKPT_PATH\n    trainer.val_before_train=True\n)\n\n# ===================================================================================\n# ===                  MAIN EXECUTION LOGIC & INFRASTRUCTURE                      ===\n# ===================================================================================\n\n# --- Boilerplate Setup ---\nset -e\nset -o pipefail\nset -x\n\n# --- Infrastructure & Boilerplate Functions ---\nstart_ray_cluster() {\n    local RAY_HEAD_WAIT_TIMEOUT=600\n    export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200\n    export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120\n\n    local ray_start_common_opts=(\n        --num-gpus \"$N_GPUS_PER_NODE\"\n        --object-store-memory 100000000000\n        --memory 100000000000\n    )\n\n    if [ \"$NNODES\" -gt 1 ]; then\n        if [ \"$NODE_RANK\" = \"0\" ]; then\n            echo \"INFO: Starting Ray head node on $(hostname)...\"\n            export RAY_ADDRESS=\"$RAY_MASTER_ADDR:$RAY_MASTER_PORT\"\n            ray start --head --port=\"$RAY_MASTER_PORT\" --dashboard-port=\"$RAY_DASHBOARD_PORT\" \"${ray_start_common_opts[@]}\" --system-config='{\"gcs_server_request_timeout_seconds\": 60, \"gcs_rpc_server_reconnect_timeout_s\": 60}'\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$RAY_ADDRESS\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head node. Exiting.\" >&2; ray stop --force; exit 1; fi\n                echo \"Head node not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head node is healthy.\"\n        else\n            local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n            echo \"INFO: Worker node $(hostname) waiting for head at $head_node_address...\"\n            local start_time=$(date +%s)\n            while ! 
ray health-check --address \"$head_node_address\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head. Exiting.\" >&2; exit 1; fi\n                echo \"Head not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head is healthy. Worker starting...\"\n            ray start --address=\"$head_node_address\" \"${ray_start_common_opts[@]}\"\n        fi\n    else\n        echo \"INFO: Starting Ray in single-node mode...\"\n        ray start --head \"${ray_start_common_opts[@]}\"\n    fi\n}\n\n# --- Main Execution Function ---\nmain() {\n    local timestamp=$(date +\"%Y%m%d_%H%M%S\")\n    ray stop --force\n\n    \n\n    export VLLM_USE_V1=1\n    export GLOO_SOCKET_TIMEOUT=600\n    export GLOO_TCP_TIMEOUT=600\n    export GLOO_LOG_LEVEL=DEBUG\n    export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}\n    export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}\n    export RAY_MASTER_ADDR=$MASTER_ADDR\n    \n    start_ray_cluster\n\n    if [ \"$NNODES\" -gt 1 ] && [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"Waiting for all $NNODES nodes to join...\"\n        local TIMEOUT=600; local start_time=$(date +%s)\n        while true; do\n            if [ \"$(( $(date +%s) - start_time ))\" -ge \"$TIMEOUT\" ]; then echo \"Error: Timeout waiting for nodes.\" >&2; exit 1; fi\n            local ready_nodes=$(ray list nodes --format=json | python3 -c \"import sys, json; print(len(json.load(sys.stdin)))\")\n            if [ \"$ready_nodes\" -ge \"$NNODES\" ]; then break; fi\n            echo \"Waiting... ($ready_nodes / $NNODES nodes ready)\"\n            sleep 5\n        done\n        echo \"All $NNODES nodes have joined.\"\n    fi\n\n    if [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"INFO [RANK 0]: Starting main training command.\"\n        eval \"${TRAINING_CMD[@]}\" \"$@\"\n        echo \"INFO [RANK 0]: Training finished.\"\n        sleep 30; ray stop --force >/dev/null 2>&1\n    elif [ \"$NNODES\" -gt 1 ]; then\n        local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n        echo \"INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address.\"\n        while ray health-check --address \"$head_node_address\" &>/dev/null; do sleep 15; done\n        echo \"INFO [RANK $NODE_RANK]: Head node down. Exiting.\"\n    fi\n\n    echo \"INFO: Script finished on rank $NODE_RANK.\"\n}\n\n# --- Script Entrypoint ---\nmain \"$@\"\n"
  },
  {
    "path": "examples/grpo_trainer/run_qwen2_5_vl-72b.sh",
    "content": "#!/usr/bin/env bash\n# ===================================================================================\n# ===                       USER CONFIGURATION SECTION                            ===\n# ===================================================================================\n\n# --- Experiment and Model Definition ---\nexport DATASET=mm_eureka\nexport ALG=grpo\nexport MODEL_NAME=qwen2.5-vl-72b\n\n# --- Path Definitions ---\nexport HOME={your_home_path}\nexport TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet\nexport TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet\nexport MODEL_PATH=$HOME/data/models/Qwen2.5-VL-72B-Instruct\n\n# Base output paths\nexport BASE_CKPT_PATH=ckpts\nexport BASE_TENSORBOARD_PATH=tensorboard\n\n# --- Key Training Hyperparameters ---\nexport TRAIN_BATCH_SIZE_PER_NODE=512\nexport PPO_MINI_BATCH_SIZE_PER_NODE=128\nexport PPO_MICRO_BATCH_SIZE_PER_GPU=8\nexport MAX_PROMPT_LENGTH=2048\nexport MAX_RESPONSE_LENGTH=4096\nexport ROLLOUT_GPU_MEMORY_UTILIZATION=0.6\nexport ROLLOUT_TP=8\nexport ROLLOUT_N=8\nexport SAVE_FREQ=30\nexport TEST_FREQ=10\nexport TOTAL_EPOCHS=30\nexport MAX_CKPT_KEEP=5\n\n# --- Multi-node (Multi-machine) distributed training environments ---\n\n# Uncomment the following line and set the correct network interface if needed for distributed backend\n# export GLOO_SOCKET_IFNAME=bond0  # Modify as needed\n\n# --- Distributed Training & Infrastructure ---\nexport N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}\nexport NNODES=${PET_NNODES:-1}\nexport NODE_RANK=${PET_NODE_RANK:-0}\nexport MASTER_ADDR=${MASTER_ADDR:-localhost}\n\n# --- Output Paths and Experiment Naming ---\nexport CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes\nexport PROJECT_NAME=siirl_${DATASET}_${ALG}\nexport EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment\nexport TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp\nexport SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp\n\n# --- Calculated Global Hyperparameters ---\nexport TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))\nexport PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))\n\n# --- Define the Training Command and its Arguments ---\nTRAINING_CMD=(\n    python3 -m siirl.main_dag\n    algorithm.adv_estimator=\\$ALG\n    data.train_files=\\$TRAIN_DATA_PATH\n    data.val_files=\\$TEST_DATA_PATH\n    data.train_batch_size=\\$TRAIN_BATCH_SIZE\n    data.max_prompt_length=\\$MAX_PROMPT_LENGTH\n    data.max_response_length=\\$MAX_RESPONSE_LENGTH\n    data.filter_overlong_prompts=True\n    data.truncation='error'\n    data.shuffle=False\n    actor_rollout_ref.model.path=\\$MODEL_PATH\n    actor_rollout_ref.actor.optim.lr=1e-6\n    actor_rollout_ref.model.use_remove_padding=True\n    actor_rollout_ref.model.use_fused_kernels=False\n    actor_rollout_ref.actor.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.actor.use_kl_loss=True\n    actor_rollout_ref.actor.grad_clip=0.5\n    actor_rollout_ref.actor.clip_ratio=0.2\n    actor_rollout_ref.actor.kl_loss_coef=0.01\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl\n    actor_rollout_ref.model.enable_gradient_checkpointing=True\n    actor_rollout_ref.actor.fsdp_config.param_offload=False\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False\n    
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1\n    actor_rollout_ref.rollout.tensor_model_parallel_size=\\$ROLLOUT_TP\n    actor_rollout_ref.rollout.name=vllm\n    actor_rollout_ref.rollout.gpu_memory_utilization=\\$ROLLOUT_GPU_MEMORY_UTILIZATION\n    actor_rollout_ref.rollout.max_model_len=8192\n    actor_rollout_ref.rollout.enable_chunked_prefill=False\n    actor_rollout_ref.rollout.enforce_eager=False\n    actor_rollout_ref.rollout.free_cache_engine=False\n    actor_rollout_ref.rollout.n=\\$ROLLOUT_N\n    actor_rollout_ref.rollout.prompt_length=\\$MAX_PROMPT_LENGTH  \n    actor_rollout_ref.rollout.response_length=\\$MAX_RESPONSE_LENGTH\n    actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1\n    actor_rollout_ref.ref.fsdp_config.param_offload=True\n    algorithm.kl_ctrl.kl_coef=0.001\n    algorithm.use_kl_in_reward=False\n    trainer.critic_warmup=0\n    trainer.logger=['console','tensorboard']\n    trainer.project_name=\\$PROJECT_NAME\n    trainer.experiment_name=\\$EXPERIMENT_NAME\n    trainer.n_gpus_per_node=\\$N_GPUS_PER_NODE\n    trainer.nnodes=\\$NNODES\n    trainer.save_freq=\\$SAVE_FREQ\n    trainer.test_freq=\\$TEST_FREQ\n    trainer.total_epochs=\\$TOTAL_EPOCHS\n    trainer.resume_mode=auto\n    trainer.max_actor_ckpt_to_keep=\\$MAX_CKPT_KEEP\n    trainer.del_local_ckpt_after_load=False\n    trainer.default_local_dir=\\$CKPT_PATH\n    trainer.val_before_train=True\n)\n\n# ===================================================================================\n# ===                  MAIN EXECUTION LOGIC & INFRASTRUCTURE                      ===\n# ===================================================================================\n\n# --- Boilerplate Setup ---\nset -e\nset -o pipefail\nset -x\n\n# --- Infrastructure & Boilerplate Functions ---\nstart_ray_cluster() {\n    local RAY_HEAD_WAIT_TIMEOUT=600\n    export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME:-bond0}\n    export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME:-bond0}\n    export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200\n    export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120\n\n    local ray_start_common_opts=(\n        --num-gpus \"$N_GPUS_PER_NODE\"\n        --object-store-memory 100000000000\n        --memory 100000000000\n    )\n\n    if [ \"$NNODES\" -gt 1 ]; then\n        if [ \"$NODE_RANK\" = \"0\" ]; then\n            echo \"INFO: Starting Ray head node on $(hostname)...\"\n            export RAY_ADDRESS=\"$RAY_MASTER_ADDR:$RAY_MASTER_PORT\"\n            ray start --head --port=\"$RAY_MASTER_PORT\" --dashboard-port=\"$RAY_DASHBOARD_PORT\" \"${ray_start_common_opts[@]}\" --system-config='{\"gcs_server_request_timeout_seconds\": 60, \"gcs_rpc_server_reconnect_timeout_s\": 60}'\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$RAY_ADDRESS\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head node. Exiting.\" >&2; ray stop --force; exit 1; fi\n                echo \"Head node not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head node is healthy.\"\n        else\n            local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n            echo \"INFO: Worker node $(hostname) waiting for head at $head_node_address...\"\n            local start_time=$(date +%s)\n            while ! 
ray health-check --address \"$head_node_address\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head. Exiting.\" >&2; exit 1; fi\n                echo \"Head not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head is healthy. Worker starting...\"\n            ray start --address=\"$head_node_address\" \"${ray_start_common_opts[@]}\"\n        fi\n    else\n        echo \"INFO: Starting Ray in single-node mode...\"\n        ray start --head \"${ray_start_common_opts[@]}\"\n    fi\n}\n\n# --- Main Execution Function ---\nmain() {\n    local timestamp=$(date +\"%Y%m%d_%H%M%S\")\n    ray stop --force\n\n    \n\n    export VLLM_USE_V1=1\n    export GLOO_SOCKET_IFNAME=bond0\n    export GLOO_SOCKET_TIMEOUT=600\n    export GLOO_TCP_TIMEOUT=600\n    export GLOO_LOG_LEVEL=DEBUG\n    export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}\n    export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}\n    export RAY_MASTER_ADDR=$MASTER_ADDR\n    \n    start_ray_cluster\n\n    if [ \"$NNODES\" -gt 1 ] && [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"Waiting for all $NNODES nodes to join...\"\n        local TIMEOUT=600; local start_time=$(date +%s)\n        while true; do\n            if [ \"$(( $(date +%s) - start_time ))\" -ge \"$TIMEOUT\" ]; then echo \"Error: Timeout waiting for nodes.\" >&2; exit 1; fi\n            local ready_nodes=$(ray list nodes --format=json | python3 -c \"import sys, json; print(len(json.load(sys.stdin)))\")\n            if [ \"$ready_nodes\" -ge \"$NNODES\" ]; then break; fi\n            echo \"Waiting... ($ready_nodes / $NNODES nodes ready)\"\n            sleep 5\n        done\n        echo \"All $NNODES nodes have joined.\"\n    fi\n\n    if [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"INFO [RANK 0]: Starting main training command.\"\n        eval \"${TRAINING_CMD[@]}\" \"$@\"\n        echo \"INFO [RANK 0]: Training finished.\"\n        sleep 30; ray stop --force >/dev/null 2>&1\n    elif [ \"$NNODES\" -gt 1 ]; then\n        local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n        echo \"INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address.\"\n        while ray health-check --address \"$head_node_address\" &>/dev/null; do sleep 15; done\n        echo \"INFO [RANK $NODE_RANK]: Head node down. Exiting.\"\n    fi\n\n    echo \"INFO: Script finished on rank $NODE_RANK.\"\n}\n\n# --- Script Entrypoint ---\nmain \"$@\"\n"
  },
  {
    "path": "examples/grpo_trainer/run_qwen2_5_vl-7b-npu.sh",
    "content": " #!/usr/bin/env bash\n# ===================================================================================\n# ===                       USER CONFIGURATION SECTION                            ===\n# ===================================================================================\n\n# --- Experiment and Model Definition ---\nexport DATASET=geo3k\nexport ALG=grpo\nexport MODEL_NAME=qwen2.5-vl-7b\n\n# --- Path Definitions ---\nexport HOME={your_home_path}\nexport TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet\nexport TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet\nexport MODEL_PATH=$HOME/data/models/Qwen2.5-VL-7B-Instruct\n\n# Base output paths\nexport BASE_CKPT_PATH=ckpts\nexport BASE_TENSORBOARD_PATH=tensorboard\n\n# --- Key Training Hyperparameters ---\nexport TRAIN_BATCH_SIZE_PER_NODE=512\nexport PPO_MINI_BATCH_SIZE_PER_NODE=32\nexport PPO_MICRO_BATCH_SIZE_PER_GPU=2\nexport MAX_PROMPT_LENGTH=1024\nexport MAX_RESPONSE_LENGTH=2048\nexport ROLLOUT_GPU_MEMORY_UTILIZATION=0.3\nexport ROLLOUT_TP=4\nexport ROLLOUT_N=5\nexport SAVE_FREQ=30\nexport TEST_FREQ=10\nexport TOTAL_EPOCHS=30\nexport MAX_CKPT_KEEP=5\n\n# --- Multi-node (Multi-machine) distributed training environments ---\n\n# Uncomment the following line and set the correct network interface if needed for distributed backend\n# export GLOO_SOCKET_IFNAME=bond0  # Modify as needed\n\n# --- Distributed Training & Infrastructure ---\nexport N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}\nexport NNODES=${PET_NNODES:-1}\nexport NODE_RANK=${PET_NODE_RANK:-0}\nexport MASTER_ADDR=${MASTER_ADDR:-localhost}\n\n# --- Output Paths and Experiment Naming ---\nexport CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes\nexport PROJECT_NAME=siirl_${DATASET}_${ALG}\nexport EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment\nexport TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp\nexport SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp\n\n# --- Calculated Global Hyperparameters ---\nexport TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))\nexport PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))\n\n# --- Define the Training Command and its Arguments ---\nTRAINING_CMD=(\n    python3 -m siirl.main_dag\n    algorithm.adv_estimator=\\$ALG\n    data.train_files=\\$TRAIN_DATA_PATH\n    data.val_files=\\$TEST_DATA_PATH\n    data.train_batch_size=\\$TRAIN_BATCH_SIZE\n    data.max_prompt_length=\\$MAX_PROMPT_LENGTH\n    data.max_response_length=\\$MAX_RESPONSE_LENGTH\n    data.filter_overlong_prompts=True\n    data.truncation='error'\n    data.image_key=images\n    actor_rollout_ref.model.path=\\$MODEL_PATH\n    actor_rollout_ref.actor.optim.lr=1e-6\n    actor_rollout_ref.model.use_remove_padding=False\n    actor_rollout_ref.actor.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.actor.use_kl_loss=True\n    actor_rollout_ref.actor.kl_loss_coef=0.01\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl\n    actor_rollout_ref.actor.entropy_coeff=0\n    actor_rollout_ref.model.enable_gradient_checkpointing=True\n    actor_rollout_ref.actor.fsdp_config.param_offload=False\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    
actor_rollout_ref.rollout.tensor_model_parallel_size=\\$ROLLOUT_TP\n    actor_rollout_ref.rollout.name=vllm\n    actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True\n    actor_rollout_ref.rollout.gpu_memory_utilization=\\$ROLLOUT_GPU_MEMORY_UTILIZATION\n    actor_rollout_ref.rollout.enable_chunked_prefill=False\n    actor_rollout_ref.rollout.enforce_eager=False\n    actor_rollout_ref.rollout.free_cache_engine=False\n    actor_rollout_ref.rollout.n=\\$ROLLOUT_N\n    actor_rollout_ref.rollout.prompt_length=\\$MAX_PROMPT_LENGTH  \n    actor_rollout_ref.rollout.response_length=\\$MAX_RESPONSE_LENGTH\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.ref.fsdp_config.param_offload=True\n    algorithm.use_kl_in_reward=False\n    trainer.critic_warmup=0\n    trainer.logger=['console','tensorboard']\n    trainer.project_name=\\$PROJECT_NAME\n    trainer.experiment_name=\\$EXPERIMENT_NAME\n    trainer.n_gpus_per_node=\\$N_GPUS_PER_NODE\n    trainer.nnodes=\\$NNODES\n    trainer.save_freq=\\$SAVE_FREQ\n    trainer.test_freq=\\$TEST_FREQ\n    trainer.total_epochs=\\$TOTAL_EPOCHS\n    trainer.resume_mode=auto\n    trainer.max_actor_ckpt_to_keep=\\$MAX_CKPT_KEEP\n    trainer.default_local_dir=\\$CKPT_PATH\n    trainer.val_before_train=True\n    trainer.device=npu\n)\n\n# ===================================================================================\n# ===                  MAIN EXECUTION LOGIC & INFRASTRUCTURE                      ===\n# ===================================================================================\n\n# --- Boilerplate Setup ---\nset -e\nset -o pipefail\nset -x\n\n# --- Infrastructure & Boilerplate Functions ---\nstart_ray_cluster() {\n    local RAY_HEAD_WAIT_TIMEOUT=600\n    export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200\n    export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120\n\n    local ray_start_common_opts=(\n        --num-gpus \"$N_GPUS_PER_NODE\"\n        --object-store-memory 100000000000\n        --memory 100000000000\n    )\n\n    if [ \"$NNODES\" -gt 1 ]; then\n        if [ \"$NODE_RANK\" = \"0\" ]; then\n            echo \"INFO: Starting Ray head node on $(hostname)...\"\n            export RAY_ADDRESS=\"$RAY_MASTER_ADDR:$RAY_MASTER_PORT\"\n            ray start --head --port=\"$RAY_MASTER_PORT\" --dashboard-port=\"$RAY_DASHBOARD_PORT\" \"${ray_start_common_opts[@]}\" --system-config='{\"gcs_server_request_timeout_seconds\": 60, \"gcs_rpc_server_reconnect_timeout_s\": 60}'\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$RAY_ADDRESS\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head node. Exiting.\" >&2; ray stop --force; exit 1; fi\n                echo \"Head node not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head node is healthy.\"\n        else\n            local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n            echo \"INFO: Worker node $(hostname) waiting for head at $head_node_address...\"\n            local start_time=$(date +%s)\n            while ! 
ray health-check --address \"$head_node_address\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head. Exiting.\" >&2; exit 1; fi\n                echo \"Head not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head is healthy. Worker starting...\"\n            ray start --address=\"$head_node_address\" \"${ray_start_common_opts[@]}\"\n        fi\n    else\n        echo \"INFO: Starting Ray in single-node mode...\"\n        ray start --head \"${ray_start_common_opts[@]}\"\n    fi\n}\n\n# --- Main Execution Function ---\nmain() {\n    local timestamp=$(date +\"%Y%m%d_%H%M%S\")\n    ray stop --force\n\n    \n\n    export GLOO_SOCKET_TIMEOUT=600\n    export GLOO_TCP_TIMEOUT=600\n    export GLOO_LOG_LEVEL=DEBUG\n    export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}\n    export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}\n    export RAY_MASTER_ADDR=$MASTER_ADDR\n    \n    start_ray_cluster\n\n    if [ \"$NNODES\" -gt 1 ] && [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"Waiting for all $NNODES nodes to join...\"\n        local TIMEOUT=600; local start_time=$(date +%s)\n        while true; do\n            if [ \"$(( $(date +%s) - start_time ))\" -ge \"$TIMEOUT\" ]; then echo \"Error: Timeout waiting for nodes.\" >&2; exit 1; fi\n            local ready_nodes=$(ray list nodes --format=json | python3 -c \"import sys, json; print(len(json.load(sys.stdin)))\")\n            if [ \"$ready_nodes\" -ge \"$NNODES\" ]; then break; fi\n            echo \"Waiting... ($ready_nodes / $NNODES nodes ready)\"\n            sleep 5\n        done\n        echo \"All $NNODES nodes have joined.\"\n    fi\n\n    if [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"INFO [RANK 0]: Starting main training command.\"\n        eval \"${TRAINING_CMD[@]}\" \"$@\"\n        echo \"INFO [RANK 0]: Training finished.\"\n        sleep 30; ray stop --force >/dev/null 2>&1\n    elif [ \"$NNODES\" -gt 1 ]; then\n        local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n        echo \"INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address.\"\n        while ray health-check --address \"$head_node_address\" &>/dev/null; do sleep 15; done\n        echo \"INFO [RANK $NODE_RANK]: Head node down. Exiting.\"\n    fi\n\n    echo \"INFO: Script finished on rank $NODE_RANK.\"\n}\n\n# --- Script Entrypoint ---\nmain \"$@\"\n"
  },
  {
    "path": "examples/grpo_trainer/run_qwen2_5_vl-7b.sh",
    "content": "\n#!/usr/bin/env bash\n# ===================================================================================\n# ===                       USER CONFIGURATION SECTION                            ===\n# ===================================================================================\n\n# --- Experiment and Model Definition ---\nexport DATASET=mm_eureka\nexport ALG=grpo\nexport MODEL_NAME=qwen2.5-vl-7b\n\n# --- Path Definitions ---\nexport TRAIN_DATA_PATH=/inspire/hdd/project/qianghuaxuexi/public/datasets/mm_eureka/train.parquet\nexport TEST_DATA_PATH=/inspire/hdd/project/qianghuaxuexi/public/datasets/mm_eureka/test.parquet\nexport MODEL_PATH=/inspire/hdd/project/qianghuaxuexi/public/models/Qwen2.5-VL-7B-Instruct\n\n# Base output paths\nexport BASE_CKPT_PATH=ckpts\nexport BASE_TENSORBOARD_PATH=tensorboard\n\n# --- Key Training Hyperparameters ---\nexport TRAIN_BATCH_SIZE_PER_NODE=512\nexport PPO_MINI_BATCH_SIZE_PER_NODE=256\nexport PPO_MICRO_BATCH_SIZE_PER_GPU=8\nexport MAX_PROMPT_LENGTH=2048\nexport MAX_RESPONSE_LENGTH=4096\nexport ROLLOUT_GPU_MEMORY_UTILIZATION=0.6\nexport ROLLOUT_TP=2\nexport ROLLOUT_N=8\nexport SAVE_FREQ=30\nexport TEST_FREQ=10\nexport TOTAL_EPOCHS=30\nexport MAX_CKPT_KEEP=5\n\n# --- Multi-node (Multi-machine) distributed training environments ---\n\n# Uncomment the following line and set the correct network interface if needed for distributed backend\n# export GLOO_SOCKET_IFNAME=bond0  # Modify as needed\n\n\n# --- Distributed Training & Infrastructure ---\nexport N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}\nexport NNODES=${PET_NNODES:-1}\nexport NODE_RANK=${PET_NODE_RANK:-0}\nexport MASTER_ADDR=${MASTER_ADDR:-localhost}\n\n# --- Output Paths and Experiment Naming ---\nexport CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes\nexport PROJECT_NAME=siirl_${DATASET}_${ALG}\nexport EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment\nexport TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp\nexport SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp\n\n# --- Calculated Global Hyperparameters ---\nexport TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))\nexport PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))\n\n\n# --- Define the Training Command and its Arguments ---\nTRAINING_CMD=(\n    python3 -m siirl.main_dag\n    algorithm.adv_estimator=\\$ALG\n    data.train_files=\\$TRAIN_DATA_PATH\n    data.val_files=\\$TEST_DATA_PATH\n    data.train_batch_size=\\$TRAIN_BATCH_SIZE\n    data.max_prompt_length=\\$MAX_PROMPT_LENGTH\n    data.max_response_length=\\$MAX_RESPONSE_LENGTH\n    data.filter_overlong_prompts=True\n    data.truncation='error'\n    data.shuffle=False\n    actor_rollout_ref.model.path=\\$MODEL_PATH\n    actor_rollout_ref.actor.optim.lr=1e-6\n    actor_rollout_ref.model.use_remove_padding=True\n    actor_rollout_ref.model.use_fused_kernels=False\n    actor_rollout_ref.actor.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.actor.use_kl_loss=True\n    actor_rollout_ref.actor.grad_clip=0.5\n    actor_rollout_ref.actor.clip_ratio=0.2\n    actor_rollout_ref.actor.kl_loss_coef=0.01\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl\n    actor_rollout_ref.model.enable_gradient_checkpointing=True\n    actor_rollout_ref.actor.fsdp_config.param_offload=False\n    
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.rollout.tensor_model_parallel_size=\\$ROLLOUT_TP\n    actor_rollout_ref.rollout.name=vllm\n    actor_rollout_ref.rollout.gpu_memory_utilization=\\$ROLLOUT_GPU_MEMORY_UTILIZATION\n    actor_rollout_ref.rollout.max_model_len=8192\n    actor_rollout_ref.rollout.enable_chunked_prefill=False\n    actor_rollout_ref.rollout.enforce_eager=False\n    actor_rollout_ref.rollout.free_cache_engine=False\n    actor_rollout_ref.rollout.n=\\$ROLLOUT_N\n    actor_rollout_ref.rollout.prompt_length=\\$MAX_PROMPT_LENGTH  \n    actor_rollout_ref.rollout.response_length=\\$MAX_RESPONSE_LENGTH\n    actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.ref.fsdp_config.param_offload=True\n    algorithm.kl_ctrl.kl_coef=0.001\n    trainer.critic_warmup=0\n    trainer.logger=['console','tensorboard']\n    trainer.project_name=\\$PROJECT_NAME\n    trainer.experiment_name=\\$EXPERIMENT_NAME\n    trainer.n_gpus_per_node=\\$N_GPUS_PER_NODE\n    trainer.nnodes=\\$NNODES\n    trainer.save_freq=\\$SAVE_FREQ\n    trainer.test_freq=\\$TEST_FREQ\n    trainer.total_epochs=\\$TOTAL_EPOCHS\n    trainer.resume_mode=auto\n    trainer.max_actor_ckpt_to_keep=\\$MAX_CKPT_KEEP\n    trainer.default_local_dir=\\$CKPT_PATH\n    trainer.val_before_train=True\n)\n\n# ===================================================================================\n# ===                  MAIN EXECUTION LOGIC & INFRASTRUCTURE                      ===\n# ===================================================================================\n\n# --- Boilerplate Setup ---\nset -e\nset -o pipefail\nset -x\n\n# --- Infrastructure & Boilerplate Functions ---\nstart_ray_cluster() {\n    local RAY_HEAD_WAIT_TIMEOUT=600\n    export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200\n    export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120\n\n    local ray_start_common_opts=(\n        --num-gpus \"$N_GPUS_PER_NODE\"\n        --object-store-memory 100000000000\n        --memory 100000000000\n    )\n\n    if [ \"$NNODES\" -gt 1 ]; then\n        if [ \"$NODE_RANK\" = \"0\" ]; then\n            echo \"INFO: Starting Ray head node on $(hostname)...\"\n            export RAY_ADDRESS=\"$RAY_MASTER_ADDR:$RAY_MASTER_PORT\"\n            ray start --head --port=\"$RAY_MASTER_PORT\" --dashboard-port=\"$RAY_DASHBOARD_PORT\" \"${ray_start_common_opts[@]}\" --system-config='{\"gcs_server_request_timeout_seconds\": 60, \"gcs_rpc_server_reconnect_timeout_s\": 60}'\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$RAY_ADDRESS\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head node. Exiting.\" >&2; ray stop --force; exit 1; fi\n                echo \"Head node not healthy yet. 
Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head node is healthy.\"\n        else\n            local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n            echo \"INFO: Worker node $(hostname) waiting for head at $head_node_address...\"\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$head_node_address\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head. Exiting.\" >&2; exit 1; fi\n                echo \"Head not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head is healthy. Worker starting...\"\n            ray start --address=\"$head_node_address\" \"${ray_start_common_opts[@]}\"\n        fi\n    else\n        echo \"INFO: Starting Ray in single-node mode...\"\n        ray start --head \"${ray_start_common_opts[@]}\"\n    fi\n}\n\n# --- Main Execution Function ---\nmain() {\n    local timestamp=$(date +\"%Y%m%d_%H%M%S\")\n    ray stop --force\n\n    export VLLM_USE_V1=1\n    export GLOO_SOCKET_TIMEOUT=600\n    export GLOO_TCP_TIMEOUT=600\n    export GLOO_LOG_LEVEL=DEBUG\n    export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}\n    export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}\n    export RAY_MASTER_ADDR=$MASTER_ADDR\n    \n    start_ray_cluster\n\n    if [ \"$NNODES\" -gt 1 ] && [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"Waiting for all $NNODES nodes to join...\"\n        local TIMEOUT=600; local start_time=$(date +%s)\n        while true; do\n            if [ \"$(( $(date +%s) - start_time ))\" -ge \"$TIMEOUT\" ]; then echo \"Error: Timeout waiting for nodes.\" >&2; exit 1; fi\n            local ready_nodes=$(ray list nodes --format=json | python3 -c \"import sys, json; print(len(json.load(sys.stdin)))\")\n            if [ \"$ready_nodes\" -ge \"$NNODES\" ]; then break; fi\n            echo \"Waiting... ($ready_nodes / $NNODES nodes ready)\"\n            sleep 5\n        done\n        echo \"All $NNODES nodes have joined.\"\n    fi\n\n    if [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"INFO [RANK 0]: Starting main training command.\"\n        eval \"${TRAINING_CMD[@]}\" \"$@\"\n        echo \"INFO [RANK 0]: Training finished.\"\n        sleep 30; ray stop --force >/dev/null 2>&1\n    elif [ \"$NNODES\" -gt 1 ]; then\n        local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n        echo \"INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address.\"\n        while ray health-check --address \"$head_node_address\" &>/dev/null; do sleep 15; done\n        echo \"INFO [RANK $NODE_RANK]: Head node down. Exiting.\"\n    fi\n\n    echo \"INFO: Script finished on rank $NODE_RANK.\"\n}\n\n# --- Script Entrypoint ---\nmain \"$@\"\n"
  },
  {
    "path": "examples/grpo_trainer/run_qwen3-235b-megatron.sh",
    "content": "#!/usr/bin/env bash\n# ===================================================================================\n# ===                       USER CONFIGURATION SECTION                            ===\n# ===================================================================================\n# --- For config debugging\nexport HYDRA_FULL_ERROR=0\nexport SIIRL_LOG_VERBOSITY=INFO\nexport RAY_DEDUP_LOGS=1\n\n# --- Experiment and Model Definition ---\nexport DATASET=deepscaler\nexport ALG=grpo\nexport MODEL_NAME=qwen3-235b-a22b\n\n# --- Path Definitions ---\nexport TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet\nexport TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet\nexport MODEL_PATH=$HOME/data/models/Qwen3-235B-A22B\n\n# Base output paths\nexport BASE_CKPT_PATH=ckpts\nexport BASE_TENSORBOARD_PATH=tensorboard\n\nexport CUDA_DEVICE_MAX_CONNECTIONS=1\n\n# --- Key Training Hyperparameters ---\nexport TRAIN_BATCH_SIZE_PER_NODE=32\nexport PPO_MINI_BATCH_SIZE_PER_NODE=32\nexport PPO_MICRO_BATCH_SIZE_PER_GPU=4\nexport MAX_PROMPT_LENGTH=$((1024 * 2))\nexport MAX_RESPONSE_LENGTH=$((1024 * 8))\nexport MAX_MODEL_LENGTH=$((1024 * 10))\n\nexport ROLLOUT_GPU_MEMORY_UTILIZATION=0.4\n\nexport ROLLOUT_TP=16\nexport ROLLOUT_N=16\nexport SAVE_FREQ=30\nexport TEST_FREQ=10\nexport TOTAL_EPOCHS=15\nexport MAX_CKPT_KEEP=5\n\nexport ACTOR_REF_PP=8\n# export ACTOR_REF_VPP=1\nexport ACTOR_REF_TP=1\nexport ACTOR_REF_EP=8\nexport ACTOR_REF_CP=1\nexport ACTOR_REF_SP=True\n\nexport use_dynamic_bsz=False\n# --- Multi-node (Multi-machine) distributed training environments ---\n\n# Uncomment the following line and set the correct network interface if needed for distributed backend\n# export GLOO_SOCKET_IFNAME=bond0  # Modify as needed\n\n# --- Distributed Training & Infrastructure ---\nexport N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}\nexport NNODES=${PET_NNODES:-1}\nexport NODE_RANK=${PET_NODE_RANK:-0}\nexport MASTER_ADDR=${MASTER_ADDR:-localhost}\n\n# --- Output Paths and Experiment Naming ---\nexport CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes\nexport PROJECT_NAME=siirl_zp_${DATASET}_${ALG}\nexport EXPERIMENT_NAME=siirl_moe_megatron_${MODEL_NAME}_${ALG}_${DATASET}_experiment\nexport TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp\nexport SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp\n\n# --- Calculated Global Hyperparameters ---\nexport TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))\nexport PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))\n\n# --- Define the Training Command and its Arguments ---\nTRAINING_CMD=(\n    python3 -m siirl.main_dag\n    algorithm.adv_estimator=\\$ALG\n    data.train_files=\\$TRAIN_DATA_PATH\n    data.val_files=\\$TEST_DATA_PATH\n    data.train_batch_size=\\$TRAIN_BATCH_SIZE\n    data.max_prompt_length=\\$MAX_PROMPT_LENGTH\n    data.max_response_length=\\$MAX_RESPONSE_LENGTH\n    data.filter_overlong_prompts=True\n    data.truncation='error'\n    data.shuffle=False\n    actor_rollout_ref.model.path=\\$MODEL_PATH\n    actor_rollout_ref.actor.optim.lr=1e-6\n    actor_rollout_ref.model.use_remove_padding=True\n    actor_rollout_ref.model.use_fused_kernels=True\n    actor_rollout_ref.model.trust_remote_code=True\n    actor_rollout_ref.model.enable_gradient_checkpointing=True\n    actor_rollout_ref.actor.strategy=megatron\n    actor_rollout_ref.actor.use_dynamic_bsz=\\$use_dynamic_bsz\n    
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=\\$use_dynamic_bsz\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=\\$use_dynamic_bsz\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=\\$ACTOR_REF_TP\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=\\$ACTOR_REF_PP\n    # actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=\\$ACTOR_REF_VPP\n    actor_rollout_ref.actor.megatron.expert_model_parallel_size=\\$ACTOR_REF_EP\n    actor_rollout_ref.actor.megatron.context_parallel_size=\\$ACTOR_REF_CP\n    actor_rollout_ref.actor.megatron.sequence_parallel=\\$ACTOR_REF_SP\n    actor_rollout_ref.actor.megatron.use_distributed_optimizer=True\n    actor_rollout_ref.actor.megatron.param_dtype=bfloat16\n    actor_rollout_ref.actor.megatron.param_offload=True\n    actor_rollout_ref.actor.megatron.optimizer_offload=True\n    actor_rollout_ref.actor.megatron.use_dist_checkpointing=False\n    actor_rollout_ref.actor.megatron.use_mbridge=True\n    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32\n    +actor_rollout_ref.actor.megatron.override_transformer_config.account_for_embedding_in_pipeline_split=True\n    +actor_rollout_ref.actor.megatron.override_transformer_config.account_for_loss_in_pipeline_split=True\n    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full\n    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1\n    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform\n    actor_rollout_ref.actor.policy_drift_coeff=0.001\n    actor_rollout_ref.actor.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.actor.use_kl_loss=True\n    actor_rollout_ref.actor.grad_clip=0.5\n    actor_rollout_ref.actor.clip_ratio=0.2\n    actor_rollout_ref.actor.kl_loss_coef=0.001\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.rollout.tensor_model_parallel_size=\\$ROLLOUT_TP\n    actor_rollout_ref.rollout.name=vllm\n    actor_rollout_ref.rollout.gpu_memory_utilization=\\$ROLLOUT_GPU_MEMORY_UTILIZATION\n    actor_rollout_ref.rollout.max_model_len=\\$MAX_MODEL_LENGTH\n    actor_rollout_ref.rollout.enable_chunked_prefill=False\n    actor_rollout_ref.rollout.enforce_eager=True\n    actor_rollout_ref.rollout.free_cache_engine=True\n    actor_rollout_ref.rollout.n=\\$ROLLOUT_N\n    actor_rollout_ref.rollout.prompt_length=\\$MAX_PROMPT_LENGTH  \n    actor_rollout_ref.rollout.response_length=\\$MAX_RESPONSE_LENGTH\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=\\$ACTOR_REF_TP\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=\\$ACTOR_REF_PP\n    # actor_rollout_ref.ref.megatron.virtual_pipeline_model_parallel_size=\\$ACTOR_REF_VPP\n    actor_rollout_ref.ref.megatron.expert_model_parallel_size=\\$ACTOR_REF_EP\n    actor_rollout_ref.ref.megatron.context_parallel_size=\\$ACTOR_REF_CP\n    actor_rollout_ref.ref.megatron.sequence_parallel=\\$ACTOR_REF_SP\n    actor_rollout_ref.ref.megatron.param_offload=True\n    actor_rollout_ref.ref.megatron.use_dist_checkpointing=False\n    algorithm.weight_factor_in_cpgd='STD_weight'\n    algorithm.kl_ctrl.kl_coef=0.001\n    trainer.critic_warmup=0\n    
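# Logging, experiment naming, checkpointing, and run-schedule settings.\n    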
trainer.logger=['console','wandb']\n    trainer.project_name=\\$PROJECT_NAME\n    trainer.experiment_name=\\$EXPERIMENT_NAME\n    trainer.n_gpus_per_node=\\$N_GPUS_PER_NODE\n    trainer.nnodes=\\$NNODES\n    trainer.save_freq=\\$SAVE_FREQ\n    trainer.test_freq=\\$TEST_FREQ\n    trainer.total_epochs=\\$TOTAL_EPOCHS\n    trainer.resume_mode=off\n    trainer.max_actor_ckpt_to_keep=\\$MAX_CKPT_KEEP\n    trainer.default_local_dir=\\$CKPT_PATH\n    trainer.val_before_train=True\n    dag.enable_perf=False\n)\n\n# ===================================================================================\n# ===                  MAIN EXECUTION LOGIC & INFRASTRUCTURE                      ===\n# ===================================================================================\n\n# --- Boilerplate Setup ---\nset -e\nset -o pipefail\nset -x\n\n# --- Infrastructure & Boilerplate Functions ---\nstart_ray_cluster() {\n    local RAY_HEAD_WAIT_TIMEOUT=600\n    export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200\n    export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120\n\n    local ray_start_common_opts=(\n        --num-gpus \"$N_GPUS_PER_NODE\"\n        --object-store-memory 100000000000\n        --memory 100000000000\n    )\n\n    if [ \"$NNODES\" -gt 1 ]; then\n        if [ \"$NODE_RANK\" = \"0\" ]; then\n            echo \"INFO: Starting Ray head node on $(hostname)...\"\n            export RAY_ADDRESS=\"$RAY_MASTER_ADDR:$RAY_MASTER_PORT\"\n            ray start --head --port=\"$RAY_MASTER_PORT\" --dashboard-port=\"$RAY_DASHBOARD_PORT\" \"${ray_start_common_opts[@]}\" --system-config='{\"gcs_server_request_timeout_seconds\": 60, \"gcs_rpc_server_reconnect_timeout_s\": 60}'\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$RAY_ADDRESS\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head node. Exiting.\" >&2; ray stop --force; exit 1; fi\n                echo \"Head node not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head node is healthy.\"\n        else\n            local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n            echo \"INFO: Worker node $(hostname) waiting for head at $head_node_address...\"\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$head_node_address\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head. Exiting.\" >&2; exit 1; fi\n                echo \"Head not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head is healthy. 
Worker starting...\"\n            ray start --address=\"$head_node_address\" \"${ray_start_common_opts[@]}\"\n        fi\n    else\n        echo \"INFO: Starting Ray in single-node mode...\"\n        ray start --head \"${ray_start_common_opts[@]}\"\n    fi\n}\n\n# --- Main Execution Function ---\nmain() {\n    local timestamp=$(date +\"%Y%m%d_%H%M%S\")\n    ray stop --force\n\n    \n\n    export VLLM_USE_V1=1\n    export GLOO_SOCKET_TIMEOUT=600\n    export GLOO_TCP_TIMEOUT=600\n    export GLOO_LOG_LEVEL=DEBUG\n    export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}\n    export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}\n    export RAY_MASTER_ADDR=$MASTER_ADDR\n    \n    start_ray_cluster\n\n    if [ \"$NNODES\" -gt 1 ] && [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"Waiting for all $NNODES nodes to join...\"\n        local TIMEOUT=600; local start_time=$(date +%s)\n        while true; do\n            if [ \"$(( $(date +%s) - start_time ))\" -ge \"$TIMEOUT\" ]; then echo \"Error: Timeout waiting for nodes.\" >&2; exit 1; fi\n            local ready_nodes=$(ray list nodes --format=json | python3 -c \"import sys, json; print(len(json.load(sys.stdin)))\")\n            if [ \"$ready_nodes\" -ge \"$NNODES\" ]; then break; fi\n            echo \"Waiting... ($ready_nodes / $NNODES nodes ready)\"\n            sleep 5\n        done\n        echo \"All $NNODES nodes have joined.\"\n    fi\n\n    if [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"INFO [RANK 0]: Starting main training command.\"\n        eval \"${TRAINING_CMD[@]}\" \"$@\"\n        echo \"INFO [RANK 0]: Training finished.\"\n        sleep 30; ray stop --force >/dev/null 2>&1\n    elif [ \"$NNODES\" -gt 1 ]; then\n        local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n        echo \"INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address.\"\n        while ray health-check --address \"$head_node_address\" &>/dev/null; do sleep 15; done\n        echo \"INFO [RANK $NODE_RANK]: Head node down. Exiting.\"\n    fi\n\n    echo \"INFO: Script finished on rank $NODE_RANK.\"\n}\n\n# --- Script Entrypoint ---\nmain \"$@\""
  },
  {
    "path": "examples/grpo_trainer/run_qwen3-235b-npu-mindspeed.sh",
    "content": " #!/usr/bin/env bash\n# ===================================================================================\n# ===                       USER CONFIGURATION SECTION                            ===\n# ===================================================================================\n\n# --- Experiment and Model Definition ---\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh\nsource /usr/local/Ascend/nnal/atb/set_env.sh\nexport LD_LIBRARY_PATH=/usr/local/Ascend/driver/:$LD_LIBRARY_PATH\nexport LD_LIBRARY_PATH=/usr/local/Ascend/driver/lib64/driver/:$LD_LIBRARY_PATH\n\nexport DATASET=deepscaler\nexport ALG=grpo\nexport MODEL_NAME=qwen3-235b\nexport VLLM_USE_V1=1\n\n# --- Path Definitions ---\nexport HOME={your_home_path}\nexport TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet\nexport TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet\nexport MODEL_PATH=$HOME/models/Qwen3-235B-A22B\n\n# Base output paths\nexport BASE_CKPT_PATH=ckpts\nexport BASE_TENSORBOARD_PATH=tensorboard\n\n# --- GLOO Configuration ---\nexport GLOO_SOCKET_TIMEOUT=600\nexport GLOO_TCP_TIMEOUT=600\nexport HCCL_CONNECT_TIMEOUT=7200\nexport HCCL_HOST_SOCKET_PORT_RANGE='auto'\nexport GLOO_LOG_LEVEL=INFO\n\n# --- Key Training Hyperparameters ---\nexport TRAIN_BATCH_SIZE_PER_NODE=64\nexport PPO_MINI_BATCH_SIZE_PER_NODE=32\nexport PPO_MICRO_BATCH_SIZE_PER_GPU=2\nexport MAX_PROMPT_LENGTH=1024\nexport MAX_RESPONSE_LENGTH=1024\nexport ROLLOUT_GPU_MEMORY_UTILIZATION=0.7\nexport ROLLOUT_TP=32\nexport ROLLOUT_N=5\nexport ACTOR_REF_TP=4\nexport ACTOR_REF_EP=8\nexport ACTOR_REF_PP=16\n\nexport ACTOR_REF_VPP=2\nexport ACTOR_REF_CP=1\nexport ACTOR_REF_SP=True\n\n\nexport SAVE_FREQ=-1\nexport TEST_FREQ=5\nexport TOTAL_EPOCHS=300\nexport MAX_CKPT_KEEP=5\n\nexport RAY_DEDUP_LOGS=0\nexport HYDRA_FULL_ERROR=0\n\n# --- Distributed Training & Infrastructure ---\nexport N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-16}\nexport NNODES=${PET_NNODES:-1}\nexport NODE_RANK=${PET_NODE_RANK:-0}\nexport MASTER_ADDR=${MASTER_ADDR:-localhost}\n\n# --- Output Paths and Experiment Naming ---\nexport CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes\nexport PROJECT_NAME=siirl_${DATASET}_${ALG}\nexport EXPERIMENT_NAME=npu_${MODEL_NAME}_tp${ACTOR_REF_TP}pp${ACTOR_REF_PP}ep${ACTOR_REF_EP}_rtp${ROLLOUT_TP}_${NNODES}_nodes_${ALG}_${DATASET}_experiment_$(date +%Y%m%d_%H%M%S)\nexport TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp\nexport SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp\n\n# --- Calculated Global Hyperparameters ---\nexport TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))\nexport PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))\n\n# --- Define the Training Command and its Arguments ---\nTRAINING_CMD=(\n    python3 -m siirl.main_dag\n    algorithm.adv_estimator=\\$ALG\n    data.train_files=\\$TRAIN_DATA_PATH\n    data.val_files=\\$TEST_DATA_PATH\n    data.train_batch_size=\\$TRAIN_BATCH_SIZE\n    data.max_prompt_length=\\$MAX_PROMPT_LENGTH\n    data.max_response_length=\\$MAX_RESPONSE_LENGTH\n    data.filter_overlong_prompts=True\n    data.truncation='error'\n    data.auto_repeat=True\n    actor_rollout_ref.model.path=\\$MODEL_PATH\n    actor_rollout_ref.model.trust_remote_code=True\n    actor_rollout_ref.actor.optim.lr=1e-6\n    actor_rollout_ref.model.use_remove_padding=True\n    actor_rollout_ref.actor.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n    
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.actor.use_kl_loss=True\n    actor_rollout_ref.actor.grad_clip=0.3 \n    actor_rollout_ref.actor.clip_ratio=0.2\n    actor_rollout_ref.actor.entropy_coeff=0\n    actor_rollout_ref.actor.kl_loss_coef=0.01\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl\n    actor_rollout_ref.actor.strategy=megatron\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=\\$ACTOR_REF_TP\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=\\$ACTOR_REF_PP\n    actor_rollout_ref.actor.megatron.context_parallel_size=\\$ACTOR_REF_CP\n    actor_rollout_ref.actor.megatron.sequence_parallel=\\$ACTOR_REF_SP\n    actor_rollout_ref.actor.megatron.expert_model_parallel_size=\\$ACTOR_REF_EP\n    actor_rollout_ref.actor.megatron.use_distributed_optimizer=True\n    actor_rollout_ref.actor.megatron.param_dtype=bfloat16\n    actor_rollout_ref.actor.megatron.param_offload=True\n    actor_rollout_ref.actor.megatron.optimizer_offload=True\n    actor_rollout_ref.actor.megatron.grad_offload=True\n    actor_rollout_ref.actor.megatron.use_dist_checkpointing=False\n    actor_rollout_ref.actor.megatron.use_mbridge=True\n    +actor_rollout_ref.actor.megatron.override_transformer_config.use_flash_attn=True\n    +actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True\n    +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=True\n    +actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True\n    +actor_rollout_ref.actor.megatron.override_transformer_config.deallocate_pipeline_outputs=True\n    +actor_rollout_ref.actor.megatron.override_transformer_config.account_for_loss_in_pipeline_split=True\n    +actor_rollout_ref.actor.megatron.override_transformer_config.account_for_embedding_in_pipeline_split=True\n    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform\n    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full\n    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1\n    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_token_dispatcher_type='alltoall'\n    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_expert_capacity_factor=1.5\n    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_permutation_async_comm=True\n    +actor_rollout_ref.actor.megatron.override_transformer_config.sequence_parallel=True\n    +actor_rollout_ref.actor.megatron.override_transformer_config.use_fused_swiglu=True\n    actor_rollout_ref.ref.megatron.expert_model_parallel_size=\\$ACTOR_REF_EP\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.model.enable_gradient_checkpointing=True\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.rollout.tensor_model_parallel_size=\\$ROLLOUT_TP\n    actor_rollout_ref.rollout.name=vllm\n    actor_rollout_ref.rollout.gpu_memory_utilization=\\$ROLLOUT_GPU_MEMORY_UTILIZATION\n    actor_rollout_ref.rollout.n=\\$ROLLOUT_N\n    actor_rollout_ref.rollout.prompt_length=\\$MAX_PROMPT_LENGTH  \n    actor_rollout_ref.rollout.response_length=\\$MAX_RESPONSE_LENGTH\n    actor_rollout_ref.rollout.enable_chunked_prefill=True\n    actor_rollout_ref.rollout.enforce_eager=True\n    actor_rollout_ref.rollout.free_cache_engine=True\n    
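# Reference-model log-prob batching and offload, then reward and trainer settings.\n    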
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.ref.megatron.param_offload=True\n    algorithm.use_kl_in_reward=False\n    trainer.critic_warmup=0\n    trainer.logger=['console','wandb']\n    trainer.project_name=\\$PROJECT_NAME\n    trainer.experiment_name=\\$EXPERIMENT_NAME\n    trainer.n_gpus_per_node=\\$N_GPUS_PER_NODE\n    trainer.nnodes=\\$NNODES\n    trainer.save_freq=\\$SAVE_FREQ\n    trainer.test_freq=\\$TEST_FREQ\n    trainer.total_epochs=\\$TOTAL_EPOCHS\n    trainer.resume_mode=auto\n    trainer.max_actor_ckpt_to_keep=\\$MAX_CKPT_KEEP\n    trainer.default_local_dir=\\$CKPT_PATH\n    trainer.val_before_train=True\n    trainer.device=npu\n)\n\n# ===================================================================================\n# ===                  MAIN EXECUTION LOGIC & INFRASTRUCTURE                      ===\n# ===================================================================================\n\n# --- Boilerplate Setup ---\nset -e\nset -o pipefail\nset -x\n\n# --- Infrastructure & Boilerplate Functions ---\nstart_ray_cluster() {\n    local RAY_HEAD_WAIT_TIMEOUT=600\n    export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200\n    export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120\n\n    local ray_start_common_opts=(\n        --num-gpus \"$N_GPUS_PER_NODE\"\n        --object-store-memory 100000000000\n        --memory 100000000000\n    )\n\n    if [ \"$NNODES\" -gt 1 ]; then\n        if [ \"$NODE_RANK\" = \"0\" ]; then\n            echo \"INFO: Starting Ray head node on $(hostname)...\"\n            export RAY_ADDRESS=\"$RAY_MASTER_ADDR:$RAY_MASTER_PORT\"\n            ray start --head --port=\"$RAY_MASTER_PORT\" --dashboard-port=\"$RAY_DASHBOARD_PORT\" \"${ray_start_common_opts[@]}\" --system-config='{\"gcs_server_request_timeout_seconds\": 60, \"gcs_rpc_server_reconnect_timeout_s\": 60}'\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$RAY_ADDRESS\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head node. Exiting.\" >&2; ray stop --force; exit 1; fi\n                echo \"Head node not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head node is healthy.\"\n        else\n            local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n            echo \"INFO: Worker node $(hostname) waiting for head at $head_node_address...\"\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$head_node_address\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head. Exiting.\" >&2; exit 1; fi\n                echo \"Head not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head is healthy. 
Worker starting...\"\n            ray start --address=\"$head_node_address\" \"${ray_start_common_opts[@]}\"\n        fi\n    else\n        echo \"INFO: Starting Ray in single-node mode...\"\n        ray start --head \"${ray_start_common_opts[@]}\"\n    fi\n}\n\n# --- Main Execution Function ---\nmain() {\n    local timestamp=$(date +\"%Y%m%d_%H%M%S\")\n    ray stop --force\n    echo \"Cleaning up residual distributed processes...\"\n    pkill -f ray || true\n    pkill -f siirl.main_dag || true\n    pkill -f torchrun || true\n    pkill -f vllm || true\n    pkill -f hccl || true\n    for port in ${MASTER_PORT:-29500} ${RAY_MASTER_PORT:-6379}; do\n        for pid in $(lsof -ti :$port); do\n            kill -9 $pid || true\n        done\n    done\n    sleep 3\n    echo \"Cleanup finished.\"\n\n    export GLOO_SOCKET_TIMEOUT=600\n    export GLOO_TCP_TIMEOUT=600\n    export GLOO_LOG_LEVEL=DEBUG\n    export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}\n    export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}\n    export RAY_MASTER_ADDR=$MASTER_ADDR\n    \n    start_ray_cluster\n\n    if [ \"$NNODES\" -gt 1 ] && [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"Waiting for all $NNODES nodes to join...\"\n        local TIMEOUT=600; local start_time=$(date +%s)\n        while true; do\n            if [ \"$(( $(date +%s) - start_time ))\" -ge \"$TIMEOUT\" ]; then echo \"Error: Timeout waiting for nodes.\" >&2; exit 1; fi\n            local ready_nodes=$(ray list nodes --format=json | python3 -c \"import sys, json; print(len(json.load(sys.stdin)))\")\n            if [ \"$ready_nodes\" -ge \"$NNODES\" ]; then break; fi\n            echo \"Waiting... ($ready_nodes / $NNODES nodes ready)\"\n            sleep 5\n        done\n        echo \"All $NNODES nodes have joined.\"\n    fi\n\n    if [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"INFO [RANK 0]: Starting main training command.\"\n        eval \"${TRAINING_CMD[@]}\" \"$@\"\n        echo \"INFO [RANK 0]: Training finished.\"\n        sleep 30; ray stop --force >/dev/null 2>&1\n    elif [ \"$NNODES\" -gt 1 ]; then\n        local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n        echo \"INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address.\"\n        while ray health-check --address \"$head_node_address\" &>/dev/null; do sleep 15; done\n        echo \"INFO [RANK $NODE_RANK]: Head node down. Exiting.\"\n    fi\n\n    echo \"INFO: Script finished on rank $NODE_RANK.\"\n}\n\n# --- Script Entrypoint ---\nmain \"$@\"\n"
  },
  {
    "path": "examples/grpo_trainer/run_qwen3-30b-npu-mindspeed.sh",
    "content": " #!/usr/bin/env bash\n# ===================================================================================\n# ===                       USER CONFIGURATION SECTION                            ===\n# ===================================================================================\n\n# --- Experiment and Model Definition ---\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh\nsource /usr/local/Ascend/nnal/atb/set_env.sh\nexport LD_LIBRARY_PATH=/usr/local/Ascend/driver/:$LD_LIBRARY_PATH\nexport LD_LIBRARY_PATH=/usr/local/Ascend/driver/lib64/driver/:$LD_LIBRARY_PATH\n\nexport DATASET=deepscaler\nexport ALG=grpo\nexport MODEL_NAME=qwen3-30b\nexport VLLM_USE_V1=1\n\n# --- Path Definitions ---\nexport HOME={your_home_path}\nexport TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet\nexport TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet\nexport MODEL_PATH=$HOME/models/Qwen3-30B-A3B\n\n# Base output paths\nexport BASE_CKPT_PATH=ckpts\nexport BASE_TENSORBOARD_PATH=tensorboard\n\n# --- GLOO Configuration ---\nexport GLOO_SOCKET_TIMEOUT=600\nexport GLOO_TCP_TIMEOUT=600\nexport HCCL_CONNECT_TIMEOUT=7200\nexport HCCL_HOST_SOCKET_PORT_RANGE='auto'\nexport GLOO_LOG_LEVEL=INFO\n\n# --- Key Training Hyperparameters ---\nexport TRAIN_BATCH_SIZE_PER_NODE=64\nexport PPO_MINI_BATCH_SIZE_PER_NODE=32\nexport PPO_MICRO_BATCH_SIZE_PER_GPU=2\nexport MAX_PROMPT_LENGTH=2048\nexport MAX_RESPONSE_LENGTH=4096\nexport ROLLOUT_GPU_MEMORY_UTILIZATION=0.7\nexport ROLLOUT_TP=8\nexport ROLLOUT_N=5\nexport ACTOR_REF_TP=4\nexport ACTOR_REF_EP=8\nexport ACTOR_REF_PP=4\nexport ACTOR_REF_CP=1\nexport ACTOR_REF_SP=True\n\nexport SAVE_FREQ=-1\nexport TEST_FREQ=5\nexport TOTAL_EPOCHS=300\nexport MAX_CKPT_KEEP=5\n\n# --- Distributed Training & Infrastructure ---\nexport N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}\nexport NNODES=${PET_NNODES:-1}\nexport NODE_RANK=${PET_NODE_RANK:-0}\nexport MASTER_ADDR=${MASTER_ADDR:-localhost}\n\n# --- Output Paths and Experiment Naming ---\nexport CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes\nexport PROJECT_NAME=siirl_${DATASET}_${ALG}\nexport EXPERIMENT_NAME=siirl_${MODEL_NAME}_tp${ACTOR_REF_TP}pp${ACTOR_REF_PP}ep${ACTOR_REF_EP}_rtp${ROLLOUT_TP}_${NNODES}_nodes_${ALG}_${DATASET}_experiment_$(date +%Y%m%d_%H%M%S)\nexport TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp\nexport SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp\n\n# --- Calculated Global Hyperparameters ---\nexport TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))\nexport PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))\n\n# --- Define the Training Command and its Arguments ---\nTRAINING_CMD=(\n    python3 -m siirl.main_dag\n    algorithm.adv_estimator=\\$ALG\n    data.train_files=\\$TRAIN_DATA_PATH\n    data.val_files=\\$TEST_DATA_PATH\n    data.train_batch_size=\\$TRAIN_BATCH_SIZE\n    data.max_prompt_length=\\$MAX_PROMPT_LENGTH\n    data.max_response_length=\\$MAX_RESPONSE_LENGTH\n    data.filter_overlong_prompts=True\n    data.truncation='error'\n    data.auto_repeat=True\n    actor_rollout_ref.model.path=\\$MODEL_PATH\n    actor_rollout_ref.actor.optim.lr=1e-6\n    actor_rollout_ref.model.use_remove_padding=True\n    actor_rollout_ref.actor.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.actor.use_kl_loss=True\n    
actor_rollout_ref.actor.grad_clip=0.3 \n    actor_rollout_ref.actor.clip_ratio=0.2\n    actor_rollout_ref.actor.entropy_coeff=0\n    actor_rollout_ref.actor.kl_loss_coef=0.01\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl\n    actor_rollout_ref.actor.strategy=megatron\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=\\$ACTOR_REF_TP\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=\\$ACTOR_REF_PP\n    actor_rollout_ref.actor.megatron.context_parallel_size=\\$ACTOR_REF_CP\n    actor_rollout_ref.actor.megatron.sequence_parallel=\\$ACTOR_REF_SP\n    actor_rollout_ref.actor.megatron.expert_model_parallel_size=\\$ACTOR_REF_EP\n    actor_rollout_ref.actor.megatron.use_distributed_optimizer=True\n    actor_rollout_ref.actor.megatron.param_dtype=bfloat16\n    actor_rollout_ref.actor.megatron.param_offload=True\n    actor_rollout_ref.actor.megatron.use_dist_checkpointing=False\n    actor_rollout_ref.actor.megatron.use_mbridge=True\n    +actor_rollout_ref.actor.megatron.override_transformer_config.use_flash_attn=True\n    +actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.model.enable_gradient_checkpointing=True\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.rollout.tensor_model_parallel_size=\\$ROLLOUT_TP\n    actor_rollout_ref.rollout.name=vllm\n    actor_rollout_ref.rollout.gpu_memory_utilization=\\$ROLLOUT_GPU_MEMORY_UTILIZATION\n    actor_rollout_ref.rollout.n=\\$ROLLOUT_N\n    actor_rollout_ref.rollout.prompt_length=\\$MAX_PROMPT_LENGTH  \n    actor_rollout_ref.rollout.response_length=\\$MAX_RESPONSE_LENGTH\n    actor_rollout_ref.rollout.enable_chunked_prefill=True\n    actor_rollout_ref.rollout.enforce_eager=True\n    actor_rollout_ref.rollout.free_cache_engine=True\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.ref.fsdp_config.param_offload=True\n    algorithm.use_kl_in_reward=False\n    trainer.critic_warmup=0\n    trainer.logger=['console','tensorboard']\n    trainer.project_name=\\$PROJECT_NAME\n    trainer.experiment_name=\\$EXPERIMENT_NAME\n    trainer.n_gpus_per_node=\\$N_GPUS_PER_NODE\n    trainer.nnodes=\\$NNODES\n    trainer.save_freq=\\$SAVE_FREQ\n    trainer.test_freq=\\$TEST_FREQ\n    trainer.total_epochs=\\$TOTAL_EPOCHS\n    trainer.resume_mode=auto\n    trainer.max_actor_ckpt_to_keep=\\$MAX_CKPT_KEEP\n    trainer.default_local_dir=\\$CKPT_PATH\n    trainer.val_before_train=True\n    trainer.device=npu\n)\n\n# ===================================================================================\n# ===                  MAIN EXECUTION LOGIC & INFRASTRUCTURE                      ===\n# ===================================================================================\n\n# --- Boilerplate Setup ---\nset -e\nset -o pipefail\nset -x\n\n# --- Infrastructure & Boilerplate Functions ---\nstart_ray_cluster() {\n    local RAY_HEAD_WAIT_TIMEOUT=600\n    export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200\n    export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120\n\n    local ray_start_common_opts=(\n        --num-gpus \"$N_GPUS_PER_NODE\"\n        --object-store-memory 100000000000\n        --memory 100000000000\n    )\n\n    if [ \"$NNODES\" 
-gt 1 ]; then\n        if [ \"$NODE_RANK\" = \"0\" ]; then\n            echo \"INFO: Starting Ray head node on $(hostname)...\"\n            export RAY_ADDRESS=\"$RAY_MASTER_ADDR:$RAY_MASTER_PORT\"\n            ray start --head --port=\"$RAY_MASTER_PORT\" --dashboard-port=\"$RAY_DASHBOARD_PORT\" \"${ray_start_common_opts[@]}\" --system-config='{\"gcs_server_request_timeout_seconds\": 60, \"gcs_rpc_server_reconnect_timeout_s\": 60}'\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$RAY_ADDRESS\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head node. Exiting.\" >&2; ray stop --force; exit 1; fi\n                echo \"Head node not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head node is healthy.\"\n        else\n            local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n            echo \"INFO: Worker node $(hostname) waiting for head at $head_node_address...\"\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$head_node_address\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head. Exiting.\" >&2; exit 1; fi\n                echo \"Head not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head is healthy. Worker starting...\"\n            ray start --address=\"$head_node_address\" \"${ray_start_common_opts[@]}\"\n        fi\n    else\n        echo \"INFO: Starting Ray in single-node mode...\"\n        ray start --head \"${ray_start_common_opts[@]}\"\n    fi\n}\n\n# --- Main Execution Function ---\nmain() {\n    local timestamp=$(date +\"%Y%m%d_%H%M%S\")\n    ray stop --force\n    echo \"Cleaning up residual distributed processes...\"\n    pkill -f ray || true\n    pkill -f siirl.main_dag || true\n    pkill -f torchrun || true\n    pkill -f vllm || true\n    pkill -f hccl || true\n    for port in ${MASTER_PORT:-29500} ${RAY_MASTER_PORT:-6379}; do\n        for pid in $(lsof -ti :$port); do\n            kill -9 $pid || true\n        done\n    done\n    sleep 3\n    echo \"Cleanup finished.\"\n\n    export GLOO_SOCKET_TIMEOUT=600\n    export GLOO_TCP_TIMEOUT=600\n    export GLOO_LOG_LEVEL=DEBUG\n    export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}\n    export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}\n    export RAY_MASTER_ADDR=$MASTER_ADDR\n    \n    start_ray_cluster\n\n    if [ \"$NNODES\" -gt 1 ] && [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"Waiting for all $NNODES nodes to join...\"\n        local TIMEOUT=600; local start_time=$(date +%s)\n        while true; do\n            if [ \"$(( $(date +%s) - start_time ))\" -ge \"$TIMEOUT\" ]; then echo \"Error: Timeout waiting for nodes.\" >&2; exit 1; fi\n            local ready_nodes=$(ray list nodes --format=json | python3 -c \"import sys, json; print(len(json.load(sys.stdin)))\")\n            if [ \"$ready_nodes\" -ge \"$NNODES\" ]; then break; fi\n            echo \"Waiting... 
($ready_nodes / $NNODES nodes ready)\"\n            sleep 5\n        done\n        echo \"All $NNODES nodes have joined.\"\n    fi\n\n    if [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"INFO [RANK 0]: Starting main training command.\"\n        eval \"${TRAINING_CMD[@]}\" \"$@\"\n        echo \"INFO [RANK 0]: Training finished.\"\n        sleep 30; ray stop --force >/dev/null 2>&1\n    elif [ \"$NNODES\" -gt 1 ]; then\n        local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n        echo \"INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address.\"\n        while ray health-check --address \"$head_node_address\" &>/dev/null; do sleep 15; done\n        echo \"INFO [RANK $NODE_RANK]: Head node down. Exiting.\"\n    fi\n\n    echo \"INFO: Script finished on rank $NODE_RANK.\"\n}\n\n# --- Script Entrypoint ---\nmain \"$@\"\n"
  },
  {
    "path": "examples/grpo_trainer/run_qwen3-8b-megatron.sh",
    "content": "#!/usr/bin/env bash\n# ===================================================================================\n# ===                       USER CONFIGURATION SECTION                            ===\n# ===================================================================================\n\n# --- For debugging\nexport HYDRA_FULL_ERROR=0\nexport SIIRL_LOG_VERBOSITY=INFO\n\n# --- Experiment and Model Definition ---\nexport DATASET=deepscaler\nexport ALG=grpo\nexport MODEL_NAME=qwen3-8b\n\n# --- Path Definitions ---\nexport TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet\nexport TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet\nexport MODEL_PATH=$HOME/data/models/Qwen3-8B\n\n# Base output paths\nexport BASE_CKPT_PATH=$HOME/ckpts\nexport BASE_TENSORBOARD_PATH=$HOME/tensorboard\n\n# --- Key Training Hyperparameters ---\nexport TRAIN_BATCH_SIZE_PER_NODE=128\nexport PPO_MINI_BATCH_SIZE_PER_NODE=16\nexport PPO_MICRO_BATCH_SIZE_PER_GPU=8\nexport MAX_PROMPT_LENGTH=2048\nexport MAX_RESPONSE_LENGTH=4096\nexport ROLLOUT_GPU_MEMORY_UTILIZATION=0.45\nexport ROLLOUT_N=8\nexport SAVE_FREQ=30\nexport TEST_FREQ=10\nexport TOTAL_EPOCHS=30\nexport MAX_CKPT_KEEP=5\n\n# ---- Key Parallelism Configuration ----\nexport ROLLOUT_TP=4\nexport ACTOR_REF_TP=4\nexport ACTOR_REF_PP=2\nexport ACTOR_REF_CP=1\nexport ACTOR_REF_SP=False\n\n# --- Multi-node (Multi-machine) distributed training environments ---\n\n# Uncomment the following line and set the correct network interface if needed for distributed backend\n# export GLOO_SOCKET_IFNAME=bond0  # Modify as needed\n\n# --- Distributed Training & Infrastructure ---\nexport N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}\nexport NNODES=${PET_NNODES:-1}\nexport NODE_RANK=${PET_NODE_RANK:-0}\nexport MASTER_ADDR=${MASTER_ADDR:-localhost}\n\n# --- Output Paths and Experiment Naming ---\nexport CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes\nexport PROJECT_NAME=siirl_${DATASET}_${ALG}\nexport EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment\nexport TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp\nexport SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp\n\n# --- Calculated Global Hyperparameters ---\nexport TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))\nexport PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))\n\n# --- Define the Training Command and its Arguments ---\nTRAINING_CMD=(\n    python3 -m siirl.main_dag\n    algorithm.adv_estimator=\\$ALG\n    data.train_files=\\$TRAIN_DATA_PATH\n    data.val_files=\\$TEST_DATA_PATH\n    data.train_batch_size=\\$TRAIN_BATCH_SIZE\n    data.max_prompt_length=\\$MAX_PROMPT_LENGTH\n    data.max_response_length=\\$MAX_RESPONSE_LENGTH\n    data.filter_overlong_prompts=True\n    data.truncation='error'\n    data.shuffle=False\n    actor_rollout_ref.model.path=\\$MODEL_PATH\n    actor_rollout_ref.model.use_remove_padding=True\n    actor_rollout_ref.model.use_fused_kernels=False\n    actor_rollout_ref.model.enable_gradient_checkpointing=True\n    actor_rollout_ref.actor.strategy=megatron\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=\\$ACTOR_REF_TP\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=\\$ACTOR_REF_PP\n    actor_rollout_ref.actor.megatron.context_parallel_size=\\$ACTOR_REF_CP\n    actor_rollout_ref.actor.megatron.sequence_parallel=\\$ACTOR_REF_SP\n    
actor_rollout_ref.actor.megatron.use_distributed_optimizer=True\n    actor_rollout_ref.actor.megatron.param_dtype=bfloat16\n    actor_rollout_ref.actor.megatron.param_offload=True\n    actor_rollout_ref.actor.megatron.use_dist_checkpointing=False\n    actor_rollout_ref.actor.megatron.seed=1\n    actor_rollout_ref.actor.optim.lr=1e-6\n    actor_rollout_ref.actor.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.actor.use_kl_loss=True\n    actor_rollout_ref.actor.grad_clip=0.5\n    actor_rollout_ref.actor.clip_ratio=0.2\n    actor_rollout_ref.actor.kl_loss_coef=0.01\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.rollout.tensor_model_parallel_size=\\$ROLLOUT_TP\n    actor_rollout_ref.rollout.name=vllm\n    actor_rollout_ref.rollout.gpu_memory_utilization=\\$ROLLOUT_GPU_MEMORY_UTILIZATION\n    actor_rollout_ref.rollout.enable_chunked_prefill=False\n    actor_rollout_ref.rollout.enforce_eager=True\n    actor_rollout_ref.rollout.free_cache_engine=True\n    actor_rollout_ref.rollout.n=\\$ROLLOUT_N\n    actor_rollout_ref.rollout.prompt_length=\\$MAX_PROMPT_LENGTH  \n    actor_rollout_ref.rollout.response_length=\\$MAX_RESPONSE_LENGTH\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=\\$ACTOR_REF_TP\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=\\$ACTOR_REF_PP\n    actor_rollout_ref.ref.megatron.param_offload=False\n    trainer.logger=['console']\n    trainer.project_name=\\$PROJECT_NAME\n    trainer.experiment_name=\\$EXPERIMENT_NAME\n    trainer.n_gpus_per_node=\\$N_GPUS_PER_NODE\n    trainer.nnodes=\\$NNODES\n    trainer.save_freq=\\$SAVE_FREQ\n    trainer.test_freq=\\$TEST_FREQ\n    trainer.total_epochs=\\$TOTAL_EPOCHS\n    trainer.resume_mode=auto\n    trainer.max_actor_ckpt_to_keep=\\$MAX_CKPT_KEEP\n    trainer.default_local_dir=\\$CKPT_PATH\n    trainer.val_before_train=True\n)\n\n# ===================================================================================\n# ===                  MAIN EXECUTION LOGIC & INFRASTRUCTURE                      ===\n# ===================================================================================\n\n# --- Boilerplate Setup ---\nset -e\nset -o pipefail\nset -x\n\n# --- Infrastructure & Boilerplate Functions ---\nstart_ray_cluster() {\n    local RAY_HEAD_WAIT_TIMEOUT=600\n    export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200\n    export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120\n\n    local ray_start_common_opts=(\n        --num-gpus \"$N_GPUS_PER_NODE\"\n        --object-store-memory 100000000000\n        --memory 100000000000\n    )\n\n    if [ \"$NNODES\" -gt 1 ]; then\n        if [ \"$NODE_RANK\" = \"0\" ]; then\n            echo \"INFO: Starting Ray head node on $(hostname)...\"\n            export RAY_ADDRESS=\"$RAY_MASTER_ADDR:$RAY_MASTER_PORT\"\n            ray start --head --port=\"$RAY_MASTER_PORT\" --dashboard-port=\"$RAY_DASHBOARD_PORT\" \"${ray_start_common_opts[@]}\" --system-config='{\"gcs_server_request_timeout_seconds\": 60, \"gcs_rpc_server_reconnect_timeout_s\": 60}'\n            local start_time=$(date +%s)\n            while ! 
ray health-check --address \"$RAY_ADDRESS\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head node. Exiting.\" >&2; ray stop --force; exit 1; fi\n                echo \"Head node not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head node is healthy.\"\n        else\n            local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n            echo \"INFO: Worker node $(hostname) waiting for head at $head_node_address...\"\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$head_node_address\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head. Exiting.\" >&2; exit 1; fi\n                echo \"Head not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head is healthy. Worker starting...\"\n            ray start --address=\"$head_node_address\" \"${ray_start_common_opts[@]}\"\n        fi\n    else\n        echo \"INFO: Starting Ray in single-node mode...\"\n        ray start --head \"${ray_start_common_opts[@]}\"\n    fi\n}\n\n# --- Main Execution Function ---\nmain() {\n    local timestamp=$(date +\"%Y%m%d_%H%M%S\")\n    ray stop --force\n\n    export VLLM_USE_V1=1\n    export GLOO_SOCKET_TIMEOUT=600\n    export GLOO_TCP_TIMEOUT=600\n    export GLOO_LOG_LEVEL=DEBUG\n    export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}\n    export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}\n    export RAY_MASTER_ADDR=$MASTER_ADDR\n    \n    start_ray_cluster\n\n    if [ \"$NNODES\" -gt 1 ] && [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"Waiting for all $NNODES nodes to join...\"\n        local TIMEOUT=600; local start_time=$(date +%s)\n        while true; do\n            if [ \"$(( $(date +%s) - start_time ))\" -ge \"$TIMEOUT\" ]; then echo \"Error: Timeout waiting for nodes.\" >&2; exit 1; fi\n            local ready_nodes=$(ray list nodes --format=json | python3 -c \"import sys, json; print(len(json.load(sys.stdin)))\")\n            if [ \"$ready_nodes\" -ge \"$NNODES\" ]; then break; fi\n            echo \"Waiting... ($ready_nodes / $NNODES nodes ready)\"\n            sleep 5\n        done\n        echo \"All $NNODES nodes have joined.\"\n    fi\n\n    if [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"INFO [RANK 0]: Starting main training command.\"\n        eval \"${TRAINING_CMD[@]}\" \"$@\"\n        echo \"INFO [RANK 0]: Training finished.\"\n        sleep 30; ray stop --force >/dev/null 2>&1\n    elif [ \"$NNODES\" -gt 1 ]; then\n        local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n        echo \"INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address.\"\n        while ray health-check --address \"$head_node_address\" &>/dev/null; do sleep 15; done\n        echo \"INFO [RANK $NODE_RANK]: Head node down. Exiting.\"\n    fi\n\n    echo \"INFO: Script finished on rank $NODE_RANK.\"\n}\n\n# --- Script Entrypoint ---\nmain \"$@\""
  },
  {
    "path": "examples/grpo_trainer/run_qwen3-8b.sh",
    "content": "#!/usr/bin/env bash\n# ===================================================================================\n# ===                       USER CONFIGURATION SECTION                            ===\n# ===================================================================================\n\n# --- Experiment and Model Definition ---\nexport DATASET=deepscaler\nexport ALG=grpo\nexport MODEL_NAME=qwen3-8b\n\n# --- Path Definitions ---\nexport HOME={your_home_path}\nexport TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet\nexport TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet\nexport MODEL_PATH=$HOME/data/models/Qwen3-8B\n\n# Base output paths\nexport BASE_CKPT_PATH=ckpts\nexport BASE_TENSORBOARD_PATH=tensorboard\n\n# --- Key Training Hyperparameters ---\nexport TRAIN_BATCH_SIZE_PER_NODE=512\nexport PPO_MINI_BATCH_SIZE_PER_NODE=256\nexport PPO_MICRO_BATCH_SIZE_PER_GPU=8\nexport MAX_PROMPT_LENGTH=2048\nexport MAX_RESPONSE_LENGTH=4096\nexport ROLLOUT_GPU_MEMORY_UTILIZATION=0.5\nexport ROLLOUT_TP=2\nexport ROLLOUT_N=8\nexport SAVE_FREQ=30\nexport TEST_FREQ=10\nexport TOTAL_EPOCHS=30\nexport MAX_CKPT_KEEP=5\n\n# --- Multi-node (Multi-machine) distributed training environments ---\n\n# Uncomment the following line and set the correct network interface if needed for distributed backend\n# export GLOO_SOCKET_IFNAME=bond0  # Modify as needed\n\n# --- Distributed Training & Infrastructure ---\nexport N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}\nexport NNODES=${PET_NNODES:-1}\nexport NODE_RANK=${PET_NODE_RANK:-0}\nexport MASTER_ADDR=${MASTER_ADDR:-localhost}\n\n# --- Output Paths and Experiment Naming ---\nexport CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes\nexport PROJECT_NAME=siirl_${DATASET}_${ALG}\nexport EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment\nexport TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp\nexport SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp\n\n# --- Calculated Global Hyperparameters ---\nexport TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))\nexport PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))\n\n# --- Define the Training Command and its Arguments ---\nTRAINING_CMD=(\n    python3 -m siirl.main_dag\n    algorithm.adv_estimator=\\$ALG\n    data.train_files=\\$TRAIN_DATA_PATH\n    data.val_files=\\$TEST_DATA_PATH\n    data.train_batch_size=\\$TRAIN_BATCH_SIZE\n    data.max_prompt_length=\\$MAX_PROMPT_LENGTH\n    data.max_response_length=\\$MAX_RESPONSE_LENGTH\n    data.filter_overlong_prompts=True\n    data.truncation='error'\n    data.shuffle=False\n    actor_rollout_ref.model.path=\\$MODEL_PATH\n    actor_rollout_ref.actor.optim.lr=1e-6\n    actor_rollout_ref.model.use_remove_padding=True\n    actor_rollout_ref.model.use_fused_kernels=False\n    actor_rollout_ref.actor.policy_drift_coeff=0.001\n    actor_rollout_ref.actor.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.actor.use_kl_loss=True\n    actor_rollout_ref.actor.grad_clip=0.5\n    actor_rollout_ref.actor.clip_ratio=0.2\n    actor_rollout_ref.actor.kl_loss_coef=0.01\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl\n    actor_rollout_ref.model.enable_gradient_checkpointing=True\n    actor_rollout_ref.actor.fsdp_config.param_offload=False\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False\n    
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.rollout.tensor_model_parallel_size=\\$ROLLOUT_TP\n    actor_rollout_ref.rollout.name=vllm\n    actor_rollout_ref.rollout.gpu_memory_utilization=\\$ROLLOUT_GPU_MEMORY_UTILIZATION\n    actor_rollout_ref.rollout.max_model_len=8192\n    actor_rollout_ref.rollout.enable_chunked_prefill=False\n    actor_rollout_ref.rollout.enforce_eager=False\n    actor_rollout_ref.rollout.free_cache_engine=False\n    actor_rollout_ref.rollout.n=\\$ROLLOUT_N\n    actor_rollout_ref.rollout.prompt_length=\\$MAX_PROMPT_LENGTH  \n    actor_rollout_ref.rollout.response_length=\\$MAX_RESPONSE_LENGTH\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.ref.fsdp_config.param_offload=True\n    algorithm.weight_factor_in_cpgd='STD_weight'\n    algorithm.kl_ctrl.kl_coef=0.001\n    trainer.critic_warmup=0\n    trainer.logger=['console','tensorboard']\n    trainer.project_name=\\$PROJECT_NAME\n    trainer.experiment_name=\\$EXPERIMENT_NAME\n    trainer.n_gpus_per_node=\\$N_GPUS_PER_NODE\n    trainer.nnodes=\\$NNODES\n    trainer.save_freq=\\$SAVE_FREQ\n    trainer.test_freq=\\$TEST_FREQ\n    trainer.total_epochs=\\$TOTAL_EPOCHS\n    trainer.resume_mode=auto\n    trainer.max_actor_ckpt_to_keep=\\$MAX_CKPT_KEEP\n    trainer.default_local_dir=\\$CKPT_PATH\n    trainer.val_before_train=True\n)\n\n# ===================================================================================\n# ===                  MAIN EXECUTION LOGIC & INFRASTRUCTURE                      ===\n# ===================================================================================\n\n# --- Boilerplate Setup ---\nset -e\nset -o pipefail\nset -x\n\n# --- Infrastructure & Boilerplate Functions ---\nstart_ray_cluster() {\n    local RAY_HEAD_WAIT_TIMEOUT=600\n    export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200\n    export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120\n\n    local ray_start_common_opts=(\n        --num-gpus \"$N_GPUS_PER_NODE\"\n        --object-store-memory 100000000000\n        --memory 100000000000\n    )\n\n    if [ \"$NNODES\" -gt 1 ]; then\n        if [ \"$NODE_RANK\" = \"0\" ]; then\n            echo \"INFO: Starting Ray head node on $(hostname)...\"\n            export RAY_ADDRESS=\"$RAY_MASTER_ADDR:$RAY_MASTER_PORT\"\n            ray start --head --port=\"$RAY_MASTER_PORT\" --dashboard-port=\"$RAY_DASHBOARD_PORT\" \"${ray_start_common_opts[@]}\" --system-config='{\"gcs_server_request_timeout_seconds\": 60, \"gcs_rpc_server_reconnect_timeout_s\": 60}'\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$RAY_ADDRESS\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head node. Exiting.\" >&2; ray stop --force; exit 1; fi\n                echo \"Head node not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head node is healthy.\"\n        else\n            local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n            echo \"INFO: Worker node $(hostname) waiting for head at $head_node_address...\"\n            local start_time=$(date +%s)\n            while ! 
ray health-check --address \"$head_node_address\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head. Exiting.\" >&2; exit 1; fi\n                echo \"Head not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head is healthy. Worker starting...\"\n            ray start --address=\"$head_node_address\" \"${ray_start_common_opts[@]}\"\n        fi\n    else\n        echo \"INFO: Starting Ray in single-node mode...\"\n        ray start --head \"${ray_start_common_opts[@]}\"\n    fi\n}\n\n# --- Main Execution Function ---\nmain() {\n    local timestamp=$(date +\"%Y%m%d_%H%M%S\")\n    ray stop --force\n\n    \n\n    export VLLM_USE_V1=1\n    export GLOO_SOCKET_TIMEOUT=600\n    export GLOO_TCP_TIMEOUT=600\n    export GLOO_LOG_LEVEL=DEBUG\n    export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}\n    export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}\n    export RAY_MASTER_ADDR=$MASTER_ADDR\n    \n    start_ray_cluster\n\n    if [ \"$NNODES\" -gt 1 ] && [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"Waiting for all $NNODES nodes to join...\"\n        local TIMEOUT=600; local start_time=$(date +%s)\n        while true; do\n            if [ \"$(( $(date +%s) - start_time ))\" -ge \"$TIMEOUT\" ]; then echo \"Error: Timeout waiting for nodes.\" >&2; exit 1; fi\n            local ready_nodes=$(ray list nodes --format=json | python3 -c \"import sys, json; print(len(json.load(sys.stdin)))\")\n            if [ \"$ready_nodes\" -ge \"$NNODES\" ]; then break; fi\n            echo \"Waiting... ($ready_nodes / $NNODES nodes ready)\"\n            sleep 5\n        done\n        echo \"All $NNODES nodes have joined.\"\n    fi\n\n    if [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"INFO [RANK 0]: Starting main training command.\"\n        eval \"${TRAINING_CMD[@]}\" \"$@\"\n        echo \"INFO [RANK 0]: Training finished.\"\n        sleep 30; ray stop --force >/dev/null 2>&1\n    elif [ \"$NNODES\" -gt 1 ]; then\n        local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n        echo \"INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address.\"\n        while ray health-check --address \"$head_node_address\" &>/dev/null; do sleep 15; done\n        echo \"INFO [RANK $NODE_RANK]: Head node down. Exiting.\"\n    fi\n\n    echo \"INFO: Script finished on rank $NODE_RANK.\"\n}\n\n# --- Script Entrypoint ---\nmain \"$@\"\n"
  },
  {
    "path": "examples/gspo_trainer/run_qwen3-1.7b.sh",
    "content": "#!/usr/bin/env bash\n# ===================================================================================\n# ===                       USER CONFIGURATION SECTION                            ===\n# ===================================================================================\n\n# --- For config debugging\nexport HYDRA_FULL_ERROR=0\nexport SIIRL_LOG_VERBOSITY=INFO\nexport RAY_DEDUP_LOGS=1\n\n# --- Experiment and Model Definition ---\nexport DATASET=deepscaler\nexport ALG=gspo\nexport MODEL_NAME=qwen3-1.7b\n\n# --- Path Definitions ---\nexport HOME={your_home_path}\nexport TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet\nexport TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet\nexport MODEL_PATH=$HOME/data/models/Qwen3-1.7B\n\n# Base output paths\nexport BASE_CKPT_PATH=ckpts\nexport BASE_TENSORBOARD_PATH=tensorboard\n\n# --- Key Training Hyperparameters ---\nexport TRAIN_BATCH_SIZE_PER_NODE=512\nexport PPO_MINI_BATCH_SIZE_PER_NODE=256\nexport PPO_MICRO_BATCH_SIZE_PER_GPU=8\nexport MAX_PROMPT_LENGTH=$((1024 * 2))\nexport MAX_RESPONSE_LENGTH=$((1024 * 4))\nexport MAX_MODEL_LENGTH=$((1024 * 6))\n\nexport ROLLOUT_GPU_MEMORY_UTILIZATION=0.5\n\nexport ROLLOUT_TP=1\nexport ROLLOUT_N=8\nexport SAVE_FREQ=30\nexport TEST_FREQ=10\nexport TOTAL_EPOCHS=30\nexport MAX_CKPT_KEEP=5\n\n# --- GSPO Specific Parameters ---\nexport LOSS_MODE=gspo\nexport ADV_ESTIMATOR=grpo\nexport CLIP_RATIO_LOW=3e-4\nexport CLIP_RATIO_HIGH=4e-4\nexport CLIP_RATIO_C=10.0\nexport LOSS_AGG_MODE=\"token-mean\"\n\n# --- KL Configuration ---\nexport USE_KL_IN_REWARD=False\nexport KL_COEF=0.001\nexport USE_KL_LOSS=True\nexport KL_LOSS_COEF=0.01\nexport KL_LOSS_TYPE=low_var_kl\n\n# --- FSDP Configuration for 1.7B ---\nexport FSDP_PARAM_OFFLOAD=False\nexport FSDP_OPTIMIZER_OFFLOAD=False\nexport REF_PARAM_OFFLOAD=True\n\n# --- Sampling Parameters ---\nexport TEMPERATURE=1.0\nexport TOP_P=1.0\nexport TOP_K=-1\n\n# --- Multi-node (Multi-machine) distributed training environments ---\n# Uncomment the following line and set the correct network interface if needed\n# export GLOO_SOCKET_IFNAME=bond0\n\n# --- Distributed Training & Infrastructure ---\nexport N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}\nexport NNODES=${PET_NNODES:-1}\nexport NODE_RANK=${PET_NODE_RANK:-0}\nexport MASTER_ADDR=${MASTER_ADDR:-localhost}\n\n# --- Output Paths and Experiment Naming ---\nexport CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_${NNODES}nodes\nexport PROJECT_NAME=siirl_${DATASET}_${ALG}\nexport EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment\nexport TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_tensorboard/dlc_${NNODES}_$timestamp\nexport SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_${NNODES}_$timestamp\n\n# --- Calculated Global Hyperparameters ---\nexport TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))\nexport PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))\n\n# --- Define the Training Command and its Arguments ---\nTRAINING_CMD=(\n    python3 -m siirl.main_dag\n    algorithm.adv_estimator=\\$ADV_ESTIMATOR\n    data.train_files=\\$TRAIN_DATA_PATH\n    data.val_files=\\$TEST_DATA_PATH\n    data.train_batch_size=\\$TRAIN_BATCH_SIZE\n    data.max_prompt_length=\\$MAX_PROMPT_LENGTH\n    data.max_response_length=\\$MAX_RESPONSE_LENGTH\n    data.filter_overlong_prompts=True\n    data.truncation='error'\n    data.shuffle=False\n    actor_rollout_ref.model.path=\\$MODEL_PATH\n    actor_rollout_ref.actor.optim.lr=1e-6\n    
actor_rollout_ref.model.use_remove_padding=True\n    actor_rollout_ref.model.use_fused_kernels=False\n    actor_rollout_ref.model.trust_remote_code=True\n    actor_rollout_ref.model.enable_gradient_checkpointing=True\n    # Actor strategy and GSPO configuration\n    actor_rollout_ref.actor.strategy=fsdp\n    actor_rollout_ref.actor.policy_loss.loss_mode=\\$LOSS_MODE\n    actor_rollout_ref.actor.loss_agg_mode=\\$LOSS_AGG_MODE\n    actor_rollout_ref.actor.clip_ratio_low=\\$CLIP_RATIO_LOW\n    actor_rollout_ref.actor.clip_ratio_high=\\$CLIP_RATIO_HIGH\n    actor_rollout_ref.actor.clip_ratio_c=\\$CLIP_RATIO_C\n    actor_rollout_ref.actor.use_kl_loss=\\$USE_KL_LOSS\n    actor_rollout_ref.actor.kl_loss_coef=\\$KL_LOSS_COEF\n    actor_rollout_ref.actor.kl_loss_type=\\$KL_LOSS_TYPE\n    actor_rollout_ref.actor.policy_drift_coeff=0.001\n    # PPO configuration\n    actor_rollout_ref.actor.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.actor.grad_clip=0.5\n    actor_rollout_ref.actor.clip_ratio=0.2\n    actor_rollout_ref.actor.entropy_coeff=0\n    # FSDP configuration for actor\n    actor_rollout_ref.actor.fsdp_config.param_offload=\\$FSDP_PARAM_OFFLOAD\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=\\$FSDP_OPTIMIZER_OFFLOAD\n    # Rollout configuration\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.rollout.tensor_model_parallel_size=\\$ROLLOUT_TP\n    actor_rollout_ref.rollout.prompt_length=\\$MAX_PROMPT_LENGTH  \n    actor_rollout_ref.rollout.response_length=\\$MAX_RESPONSE_LENGTH\n    actor_rollout_ref.rollout.name=vllm\n    actor_rollout_ref.rollout.n=\\$ROLLOUT_N\n    actor_rollout_ref.rollout.gpu_memory_utilization=\\$ROLLOUT_GPU_MEMORY_UTILIZATION\n    actor_rollout_ref.rollout.max_model_len=\\$MAX_MODEL_LENGTH\n    actor_rollout_ref.rollout.enable_chunked_prefill=False\n    actor_rollout_ref.rollout.enforce_eager=False\n    actor_rollout_ref.rollout.free_cache_engine=False\n    actor_rollout_ref.rollout.temperature=\\$TEMPERATURE\n    actor_rollout_ref.rollout.top_p=\\$TOP_P\n    actor_rollout_ref.rollout.top_k=\\$TOP_K\n    actor_rollout_ref.rollout.calculate_log_probs=True\n    # Reference model configuration\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.ref.fsdp_config.param_offload=\\$REF_PARAM_OFFLOAD\n    # Algorithm configuration\n    algorithm.weight_factor_in_cpgd='STD_weight'\n    algorithm.use_kl_in_reward=\\$USE_KL_IN_REWARD\n    algorithm.kl_ctrl.kl_coef=\\$KL_COEF\n    # Trainer configuration\n    trainer.critic_warmup=0\n    trainer.logger='[\"console\",\"tensorboard\",\"wandb\"]'\n    trainer.project_name=\\$PROJECT_NAME\n    trainer.experiment_name=\\$EXPERIMENT_NAME\n    trainer.n_gpus_per_node=\\$N_GPUS_PER_NODE\n    trainer.nnodes=\\$NNODES\n    trainer.save_freq=\\$SAVE_FREQ\n    trainer.test_freq=\\$TEST_FREQ\n    trainer.total_epochs=\\$TOTAL_EPOCHS\n    trainer.resume_mode=auto\n    trainer.max_actor_ckpt_to_keep=\\$MAX_CKPT_KEEP\n    trainer.default_local_dir=\\$CKPT_PATH\n    trainer.val_before_train=True\n    dag.enable_perf=False\n)\n\n# ===================================================================================\n# ===                  MAIN EXECUTION LOGIC & INFRASTRUCTURE                      ===\n# ===================================================================================\n\n# --- 
Boilerplate Setup ---\nset -e\nset -o pipefail\nset -x\n\n# --- Infrastructure & Boilerplate Functions ---\nstart_ray_cluster() {\n    local RAY_HEAD_WAIT_TIMEOUT=600\n    export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200\n    export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120\n\n    local ray_start_common_opts=(\n        --num-gpus \"$N_GPUS_PER_NODE\"\n        --object-store-memory 100000000000\n        --memory 100000000000\n    )\n\n    if [ \"$NNODES\" -gt 1 ]; then\n        if [ \"$NODE_RANK\" = \"0\" ]; then\n            echo \"INFO: Starting Ray head node on $(hostname)...\"\n            export RAY_ADDRESS=\"$RAY_MASTER_ADDR:$RAY_MASTER_PORT\"\n            ray start --head --port=\"$RAY_MASTER_PORT\" --dashboard-port=\"$RAY_DASHBOARD_PORT\" \"${ray_start_common_opts[@]}\" --system-config='{\"gcs_server_request_timeout_seconds\": 60, \"gcs_rpc_server_reconnect_timeout_s\": 60}'\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$RAY_ADDRESS\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head node. Exiting.\" >&2; ray stop --force; exit 1; fi\n                echo \"Head node not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head node is healthy.\"\n        else\n            local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n            echo \"INFO: Worker node $(hostname) waiting for head at $head_node_address...\"\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$head_node_address\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head. Exiting.\" >&2; exit 1; fi\n                echo \"Head not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head is healthy. Worker starting...\"\n            ray start --address=\"$head_node_address\" \"${ray_start_common_opts[@]}\"\n        fi\n    else\n        echo \"INFO: Starting Ray in single-node mode...\"\n        ray start --head \"${ray_start_common_opts[@]}\"\n    fi\n}\n\n# --- Main Execution Function ---\nmain() {\n    local timestamp=$(date +\"%Y%m%d_%H%M%S\")\n    ray stop --force\n\n    export VLLM_USE_V1=1\n    export GLOO_SOCKET_TIMEOUT=600\n    export GLOO_TCP_TIMEOUT=600\n    export GLOO_LOG_LEVEL=DEBUG\n    export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}\n    export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}\n    export RAY_MASTER_ADDR=$MASTER_ADDR\n\n    start_ray_cluster\n\n    if [ \"$NNODES\" -gt 1 ] && [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"Waiting for all $NNODES nodes to join...\"\n        local TIMEOUT=600; local start_time=$(date +%s)\n        while true; do\n            if [ \"$(( $(date +%s) - start_time ))\" -ge \"$TIMEOUT\" ]; then echo \"Error: Timeout waiting for nodes.\" >&2; exit 1; fi\n            local ready_nodes=$(ray list nodes --format=json | python3 -c \"import sys, json; print(len(json.load(sys.stdin)))\")\n            if [ \"$ready_nodes\" -ge \"$NNODES\" ]; then break; fi\n            echo \"Waiting... 
($ready_nodes / $NNODES nodes ready)\"\n            sleep 5\n        done\n        echo \"All $NNODES nodes have joined.\"\n    fi\n\n    if [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"INFO [RANK 0]: Starting GSPO training command.\"\n        echo \"Command: ${TRAINING_CMD[*]}\"\n        eval \"${TRAINING_CMD[@]}\" \"$@\"\n        echo \"INFO [RANK 0]: Training finished.\"\n        sleep 30; ray stop --force >/dev/null 2>&1\n    elif [ \"$NNODES\" -gt 1 ]; then\n        local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n        echo \"INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address.\"\n        while ray health-check --address \"$head_node_address\" &>/dev/null; do sleep 15; done\n        echo \"INFO [RANK $NODE_RANK]: Head node down. Exiting.\"\n    fi\n\n    echo \"INFO: Script finished on rank $NODE_RANK.\"\n}\n\n# --- Script Entrypoint ---\nif [[ \"${BASH_SOURCE[0]}\" == \"${0}\" ]]; then\n    main \"$@\"\nfi\n"
  },
  {
    "path": "examples/gspo_trainer/run_qwen3-235b-megatron.sh",
    "content": "#!/usr/bin/env bash\n# ===================================================================================\n# ===                       USER CONFIGURATION SECTION                            ===\n# ===================================================================================\n\n# --- For config debugging\nexport HYDRA_FULL_ERROR=0\nexport SIIRL_LOG_VERBOSITY=INFO\nexport RAY_DEDUP_LOGS=1\n\n# --- Experiment and Model Definition ---\nexport DATASET=deepscaler\nexport ALG=gspo\nexport MODEL_NAME=qwen3-235b-a22b\n\n# --- Path Definitions ---\nexport HOME={your_home_path}\nexport TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet\nexport TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet\nexport MODEL_PATH=$HOME/data/models/Qwen3-235B-A22B\n\n# Base output paths\nexport BASE_CKPT_PATH=ckpts\nexport BASE_TENSORBOARD_PATH=tensorboard\n\nexport CUDA_DEVICE_MAX_CONNECTIONS=1\n\n# --- Key Training Hyperparameters ---\nexport TRAIN_BATCH_SIZE_PER_NODE=32  # Conservative for 235B\nexport PPO_MINI_BATCH_SIZE_PER_NODE=32\nexport PPO_MICRO_BATCH_SIZE_PER_GPU=4\nexport MAX_PROMPT_LENGTH=$((1024 * 2))\nexport MAX_RESPONSE_LENGTH=$((1024 * 8))\nexport MAX_MODEL_LENGTH=$((1024 * 10))\n\nexport ROLLOUT_GPU_MEMORY_UTILIZATION=0.4  # Conservative for 235B\n\nexport ROLLOUT_TP=16  # High TP for 235B\nexport ROLLOUT_N=16\nexport SAVE_FREQ=30\nexport TEST_FREQ=10\nexport TOTAL_EPOCHS=15\nexport MAX_CKPT_KEEP=5\n\n# --- GSPO Specific Parameters ---\nexport LOSS_MODE=gspo\nexport ADV_ESTIMATOR=grpo\nexport CLIP_RATIO_LOW=3e-4\nexport CLIP_RATIO_HIGH=4e-4\nexport CLIP_RATIO_C=10.0\nexport LOSS_AGG_MODE=\"token-mean\"\n\n# --- KL Configuration ---\nexport USE_KL_IN_REWARD=False\nexport KL_COEF=0.001\nexport USE_KL_LOSS=True\nexport KL_LOSS_COEF=0.001\nexport KL_LOSS_TYPE=low_var_kl\n\n# --- Megatron Parallelism for 235B ---\nexport ACTOR_REF_PP=8  # High pipeline parallel for 235B\nexport ACTOR_REF_TP=1  # Low tensor parallel\nexport ACTOR_REF_EP=8  # High expert parallel for MoE\nexport ACTOR_REF_CP=1  # Context parallel\nexport ACTOR_REF_SP=True  # Sequence parallel\n\nexport use_dynamic_bsz=False\n\n# --- Multi-node (Multi-machine) distributed training environments ---\n# Uncomment the following line and set the correct network interface if needed\n# export GLOO_SOCKET_IFNAME=bond0\n\n# --- Distributed Training & Infrastructure ---\nexport N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}\nexport NNODES=${PET_NNODES:-1}\nexport NODE_RANK=${PET_NODE_RANK:-0}\nexport MASTER_ADDR=${MASTER_ADDR:-localhost}\n\n# --- Output Paths and Experiment Naming ---\nexport CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes\nexport PROJECT_NAME=siirl_zp_${DATASET}_${ALG}\nexport EXPERIMENT_NAME=siirl_moe_megatron_${MODEL_NAME}_${ALG}_${DATASET}_experiment\nexport TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp\nexport SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp\n\n# --- Calculated Global Hyperparameters ---\nexport TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))\nexport PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))\n\n# --- Define the Training Command and its Arguments ---\nTRAINING_CMD=(\n    python3 -m siirl.main_dag\n    algorithm.adv_estimator=\\$ADV_ESTIMATOR\n    data.train_files=\\$TRAIN_DATA_PATH\n    data.val_files=\\$TEST_DATA_PATH\n    data.train_batch_size=\\$TRAIN_BATCH_SIZE\n    data.max_prompt_length=\\$MAX_PROMPT_LENGTH\n    
data.max_response_length=\\$MAX_RESPONSE_LENGTH\n    data.filter_overlong_prompts=True\n    data.truncation='error'\n    data.shuffle=False\n    actor_rollout_ref.model.path=\\$MODEL_PATH\n    actor_rollout_ref.actor.optim.lr=1e-6\n    actor_rollout_ref.model.use_remove_padding=True\n    actor_rollout_ref.model.use_fused_kernels=True\n    actor_rollout_ref.model.trust_remote_code=True\n    actor_rollout_ref.model.enable_gradient_checkpointing=True\n    actor_rollout_ref.actor.strategy=megatron\n    actor_rollout_ref.actor.use_dynamic_bsz=\\$use_dynamic_bsz\n    # GSPO specific loss configuration\n    actor_rollout_ref.actor.policy_loss.loss_mode=\\$LOSS_MODE\n    actor_rollout_ref.actor.loss_agg_mode=\\$LOSS_AGG_MODE\n    actor_rollout_ref.actor.clip_ratio_low=\\$CLIP_RATIO_LOW\n    actor_rollout_ref.actor.clip_ratio_high=\\$CLIP_RATIO_HIGH\n    actor_rollout_ref.actor.clip_ratio_c=\\$CLIP_RATIO_C\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=\\$use_dynamic_bsz\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=\\$use_dynamic_bsz\n    # Megatron configuration for actor (235B)\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=\\$ACTOR_REF_TP\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=\\$ACTOR_REF_PP\n    actor_rollout_ref.actor.megatron.expert_model_parallel_size=\\$ACTOR_REF_EP\n    actor_rollout_ref.actor.megatron.context_parallel_size=\\$ACTOR_REF_CP\n    actor_rollout_ref.actor.megatron.sequence_parallel=\\$ACTOR_REF_SP\n    actor_rollout_ref.actor.megatron.use_distributed_optimizer=True\n    actor_rollout_ref.actor.megatron.param_dtype=bfloat16\n    actor_rollout_ref.actor.megatron.param_offload=True\n    actor_rollout_ref.actor.megatron.optimizer_offload=True\n    actor_rollout_ref.actor.megatron.use_dist_checkpointing=False\n    actor_rollout_ref.actor.megatron.use_mbridge=True\n    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32\n    +actor_rollout_ref.actor.megatron.override_transformer_config.account_for_embedding_in_pipeline_split=True\n    +actor_rollout_ref.actor.megatron.override_transformer_config.account_for_loss_in_pipeline_split=True\n    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full\n    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1\n    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform\n    +actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion=True\n    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_permute_fusion=True\n    # PPO configuration\n    actor_rollout_ref.actor.policy_drift_coeff=0.001\n    actor_rollout_ref.actor.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.actor.use_kl_loss=\\$USE_KL_LOSS\n    actor_rollout_ref.actor.grad_clip=0.5\n    actor_rollout_ref.actor.clip_ratio=0.2\n    actor_rollout_ref.actor.kl_loss_coef=\\$KL_LOSS_COEF\n    actor_rollout_ref.actor.kl_loss_type=\\$KL_LOSS_TYPE\n    # Rollout configuration (235B)\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.rollout.tensor_model_parallel_size=\\$ROLLOUT_TP\n    actor_rollout_ref.rollout.prompt_length=\\$MAX_PROMPT_LENGTH  \n    actor_rollout_ref.rollout.response_length=\\$MAX_RESPONSE_LENGTH\n    actor_rollout_ref.rollout.name=vllm\n    
actor_rollout_ref.rollout.gpu_memory_utilization=\\$ROLLOUT_GPU_MEMORY_UTILIZATION\n    actor_rollout_ref.rollout.max_model_len=\\$MAX_MODEL_LENGTH\n    actor_rollout_ref.rollout.enable_chunked_prefill=False\n    actor_rollout_ref.rollout.enforce_eager=True\n    actor_rollout_ref.rollout.free_cache_engine=True\n    actor_rollout_ref.rollout.n=\\$ROLLOUT_N\n    # Reference model configuration (235B)\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=\\$ACTOR_REF_TP\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=\\$ACTOR_REF_PP\n    actor_rollout_ref.ref.megatron.expert_model_parallel_size=\\$ACTOR_REF_EP\n    actor_rollout_ref.ref.megatron.context_parallel_size=\\$ACTOR_REF_CP\n    actor_rollout_ref.ref.megatron.sequence_parallel=\\$ACTOR_REF_SP\n    actor_rollout_ref.ref.megatron.param_offload=True\n    actor_rollout_ref.ref.megatron.use_dist_checkpointing=False\n    # Algorithm configuration\n    algorithm.weight_factor_in_cpgd='STD_weight'\n    algorithm.use_kl_in_reward=\\$USE_KL_IN_REWARD\n    algorithm.kl_ctrl.kl_coef=\\$KL_COEF\n    # Trainer configuration\n    trainer.critic_warmup=0\n    trainer.logger='[\"console\",\"tensorboard\"]'\n    trainer.project_name=\\$PROJECT_NAME\n    trainer.experiment_name=\\$EXPERIMENT_NAME\n    trainer.n_gpus_per_node=\\$N_GPUS_PER_NODE\n    trainer.nnodes=\\$NNODES\n    trainer.save_freq=\\$SAVE_FREQ\n    trainer.test_freq=\\$TEST_FREQ\n    trainer.total_epochs=\\$TOTAL_EPOCHS\n    trainer.resume_mode=off\n    trainer.max_actor_ckpt_to_keep=\\$MAX_CKPT_KEEP\n    trainer.default_local_dir=\\$CKPT_PATH\n    trainer.val_before_train=True\n    dag.enable_perf=False\n)\n\n# ===================================================================================\n# ===                  MAIN EXECUTION LOGIC & INFRASTRUCTURE                      ===\n# ===================================================================================\n\n# --- Boilerplate Setup ---\nset -e\nset -o pipefail\nset -x\n\n# --- Infrastructure & Boilerplate Functions ---\nstart_ray_cluster() {\n    local RAY_HEAD_WAIT_TIMEOUT=600\n    export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200\n    export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120\n\n    local ray_start_common_opts=(\n        --num-gpus \"$N_GPUS_PER_NODE\"\n        --object-store-memory 100000000000\n        --memory 100000000000\n    )\n\n    if [ \"$NNODES\" -gt 1 ]; then\n        if [ \"$NODE_RANK\" = \"0\" ]; then\n            echo \"INFO: Starting Ray head node on $(hostname)...\"\n            export RAY_ADDRESS=\"$RAY_MASTER_ADDR:$RAY_MASTER_PORT\"\n            ray start --head --port=\"$RAY_MASTER_PORT\" --dashboard-port=\"$RAY_DASHBOARD_PORT\" \"${ray_start_common_opts[@]}\" --system-config='{\"gcs_server_request_timeout_seconds\": 60, \"gcs_rpc_server_reconnect_timeout_s\": 60}'\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$RAY_ADDRESS\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head node. Exiting.\" >&2; ray stop --force; exit 1; fi\n                echo \"Head node not healthy yet. 
Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head node is healthy.\"\n        else\n            local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n            echo \"INFO: Worker node $(hostname) waiting for head at $head_node_address...\"\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$head_node_address\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head. Exiting.\" >&2; exit 1; fi\n                echo \"Head not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head is healthy. Worker starting...\"\n            ray start --address=\"$head_node_address\" \"${ray_start_common_opts[@]}\"\n        fi\n    else\n        echo \"INFO: Starting Ray in single-node mode...\"\n        ray start --head \"${ray_start_common_opts[@]}\"\n    fi\n}\n\n# --- Main Execution Function ---\nmain() {\n    local timestamp=$(date +\"%Y%m%d_%H%M%S\")\n    ray stop --force\n\n    export VLLM_USE_V1=1\n    export GLOO_SOCKET_TIMEOUT=600\n    export GLOO_TCP_TIMEOUT=600\n    export GLOO_LOG_LEVEL=DEBUG\n    export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}\n    export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}\n    export RAY_MASTER_ADDR=$MASTER_ADDR\n\n    start_ray_cluster\n\n    if [ \"$NNODES\" -gt 1 ] && [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"Waiting for all $NNODES nodes to join...\"\n        local TIMEOUT=600; local start_time=$(date +%s)\n        while true; do\n            if [ \"$(( $(date +%s) - start_time ))\" -ge \"$TIMEOUT\" ]; then echo \"Error: Timeout waiting for nodes.\" >&2; exit 1; fi\n            local ready_nodes=$(ray list nodes --format=json | python3 -c \"import sys, json; print(len(json.load(sys.stdin)))\")\n            if [ \"$ready_nodes\" -ge \"$NNODES\" ]; then break; fi\n            echo \"Waiting... ($ready_nodes / $NNODES nodes ready)\"\n            sleep 5\n        done\n        echo \"All $NNODES nodes have joined.\"\n    fi\n\n    if [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"INFO [RANK 0]: Starting GSPO training command.\"\n        echo \"Command: ${TRAINING_CMD[*]}\"\n        eval \"${TRAINING_CMD[@]}\" \"$@\"\n        echo \"INFO [RANK 0]: Training finished.\"\n        sleep 30; ray stop --force >/dev/null 2>&1\n    elif [ \"$NNODES\" -gt 1 ]; then\n        local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n        echo \"INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address.\"\n        while ray health-check --address \"$head_node_address\" &>/dev/null; do sleep 15; done\n        echo \"INFO [RANK $NODE_RANK]: Head node down. Exiting.\"\n    fi\n\n    echo \"INFO: Script finished on rank $NODE_RANK.\"\n}\n\n# --- Script Entrypoint ---\nif [[ \"${BASH_SOURCE[0]}\" == \"${0}\" ]]; then\n    main \"$@\"\nfi\n"
  },
  {
    "path": "examples/gspo_trainer/run_qwen3-30b-gspo-megatron.sh",
    "content": "#!/usr/bin/env bash\n# ===================================================================================\n# ===                       USER CONFIGURATION SECTION                            ===\n# ===================================================================================\n\n# --- For config debugging\nexport HYDRA_FULL_ERROR=0\nexport SIIRL_LOG_VERBOSITY=INFO\nexport RAY_DEDUP_LOGS=1\n\n# --- Experiment and Model Definition ---\nexport DATASET=deepscaler\nexport ALG=gspo\nexport MODEL_NAME=qwen3-30b-a3b\n\n# --- Path Definitions ---\nexport HOME={your_home_path}\nexport TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet\nexport TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet\nexport MODEL_PATH=$HOME/data/models/Qwen3-30B-A3B-Base\n\n# Base output paths\nexport BASE_CKPT_PATH=ckpts\nexport BASE_TENSORBOARD_PATH=tensorboard\n\nexport CUDA_DEVICE_MAX_CONNECTIONS=1\n\n# --- Key Training Hyperparameters ---\nexport TRAIN_BATCH_SIZE_PER_NODE=64  # Increased for 30B\nexport PPO_MINI_BATCH_SIZE_PER_NODE=32\nexport PPO_MICRO_BATCH_SIZE_PER_GPU=2\nexport MAX_PROMPT_LENGTH=$((1024 * 2))\nexport MAX_RESPONSE_LENGTH=$((1024 * 8))\nexport MAX_MODEL_LENGTH=$((1024 * 10))\n\nexport ROLLOUT_GPU_MEMORY_UTILIZATION=0.6  # Higher for 30B\n\nexport ROLLOUT_TP=4  # Reduced for 30B\nexport ROLLOUT_N=16\nexport SAVE_FREQ=30\nexport TEST_FREQ=10\nexport TOTAL_EPOCHS=15\nexport MAX_CKPT_KEEP=5\n\n# --- GSPO Specific Parameters ---\nexport LOSS_MODE=gspo\nexport ADV_ESTIMATOR=grpo\nexport CLIP_RATIO_LOW=3e-4\nexport CLIP_RATIO_HIGH=4e-4\nexport CLIP_RATIO_C=10.0\nexport LOSS_AGG_MODE=\"token-mean\"\n\n# --- KL Configuration ---\nexport USE_KL_IN_REWARD=False\nexport KL_COEF=0.001\nexport USE_KL_LOSS=True\nexport KL_LOSS_COEF=0.001\nexport KL_LOSS_TYPE=low_var_kl\n\n# --- Megatron Parallelism for 30B ---\nexport ACTOR_REF_PP=2  # Reduced pipeline parallel\nexport ACTOR_REF_TP=4  # Tensor parallel\nexport ACTOR_REF_EP=1  # No expert parallel for 30B\nexport ACTOR_REF_CP=1  # Context parallel\nexport ACTOR_REF_SP=True  # Sequence parallel\n\nexport use_dynamic_bsz=False\n\n# --- Multi-node (Multi-machine) distributed training environments ---\n# Uncomment the following line and set the correct network interface if needed\n# export GLOO_SOCKET_IFNAME=bond0\n\n# --- Distributed Training & Infrastructure ---\nexport N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}\nexport NNODES=${PET_NNODES:-1}\nexport NODE_RANK=${PET_NODE_RANK:-0}\nexport MASTER_ADDR=${MASTER_ADDR:-localhost}\n\n# --- Output Paths and Experiment Naming ---\nexport CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_${NNODES}nodes\nexport PROJECT_NAME=siirl_${DATASET}_${ALG}\nexport EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment\nexport TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_tensorboard/dlc_${NNODES}_$timestamp\nexport SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_${NNODES}_$timestamp\n\n# --- Calculated Global Hyperparameters ---\nexport TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))\nexport PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))\n\n# --- Define the Training Command and its Arguments ---\nTRAINING_CMD=(\n    python3 -m siirl.main_dag\n    algorithm.adv_estimator=\\$ADV_ESTIMATOR\n    data.train_files=\\$TRAIN_DATA_PATH\n    data.val_files=\\$TEST_DATA_PATH\n    data.train_batch_size=\\$TRAIN_BATCH_SIZE\n    data.max_prompt_length=\\$MAX_PROMPT_LENGTH\n    data.max_response_length=\\$MAX_RESPONSE_LENGTH\n    
data.filter_overlong_prompts=True\n    data.truncation='error'\n    data.shuffle=False\n    actor_rollout_ref.model.path=\\$MODEL_PATH\n    actor_rollout_ref.actor.optim.lr=1e-6\n    actor_rollout_ref.model.use_remove_padding=True\n    actor_rollout_ref.model.use_fused_kernels=True\n    actor_rollout_ref.model.trust_remote_code=True\n    actor_rollout_ref.model.enable_gradient_checkpointing=True\n    actor_rollout_ref.actor.strategy=megatron\n    actor_rollout_ref.actor.use_dynamic_bsz=\\$use_dynamic_bsz\n    # GSPO specific loss configuration\n    actor_rollout_ref.actor.policy_loss.loss_mode=\\$LOSS_MODE\n    actor_rollout_ref.actor.loss_agg_mode=\\$LOSS_AGG_MODE\n    actor_rollout_ref.actor.clip_ratio_low=\\$CLIP_RATIO_LOW\n    actor_rollout_ref.actor.clip_ratio_high=\\$CLIP_RATIO_HIGH\n    actor_rollout_ref.actor.clip_ratio_c=\\$CLIP_RATIO_C\n    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=\\$use_dynamic_bsz\n    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=\\$use_dynamic_bsz\n    # Megatron configuration for actor\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=\\$ACTOR_REF_TP\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=\\$ACTOR_REF_PP\n    actor_rollout_ref.actor.megatron.expert_model_parallel_size=\\$ACTOR_REF_EP\n    actor_rollout_ref.actor.megatron.context_parallel_size=\\$ACTOR_REF_CP\n    actor_rollout_ref.actor.megatron.sequence_parallel=\\$ACTOR_REF_SP\n    actor_rollout_ref.actor.megatron.use_distributed_optimizer=True\n    actor_rollout_ref.actor.megatron.param_dtype=bfloat16\n    actor_rollout_ref.actor.megatron.param_offload=True\n    actor_rollout_ref.actor.megatron.optimizer_offload=True\n    actor_rollout_ref.actor.megatron.use_dist_checkpointing=False\n    actor_rollout_ref.actor.megatron.use_mbridge=True\n    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32\n    +actor_rollout_ref.actor.megatron.override_transformer_config.account_for_embedding_in_pipeline_split=True\n    +actor_rollout_ref.actor.megatron.override_transformer_config.account_for_loss_in_pipeline_split=True\n    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full\n    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1\n    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform\n    # PPO configuration\n    actor_rollout_ref.actor.policy_drift_coeff=0.001\n    actor_rollout_ref.actor.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.actor.use_kl_loss=\\$USE_KL_LOSS\n    actor_rollout_ref.actor.grad_clip=0.5\n    actor_rollout_ref.actor.clip_ratio=0.2\n    actor_rollout_ref.actor.kl_loss_coef=\\$KL_LOSS_COEF\n    actor_rollout_ref.actor.kl_loss_type=\\$KL_LOSS_TYPE\n    # Rollout configuration\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.rollout.tensor_model_parallel_size=\\$ROLLOUT_TP\n    actor_rollout_ref.rollout.prompt_length=\\$MAX_PROMPT_LENGTH  \n    actor_rollout_ref.rollout.response_length=\\$MAX_RESPONSE_LENGTH\n    actor_rollout_ref.rollout.name=vllm\n    actor_rollout_ref.rollout.gpu_memory_utilization=\\$ROLLOUT_GPU_MEMORY_UTILIZATION\n    actor_rollout_ref.rollout.max_model_len=\\$MAX_MODEL_LENGTH\n    actor_rollout_ref.rollout.enable_chunked_prefill=False\n    actor_rollout_ref.rollout.enforce_eager=True\n    
actor_rollout_ref.rollout.free_cache_engine=True\n    actor_rollout_ref.rollout.n=\\$ROLLOUT_N\n    # Reference model configuration\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=\\$ACTOR_REF_TP\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=\\$ACTOR_REF_PP\n    actor_rollout_ref.ref.megatron.expert_model_parallel_size=\\$ACTOR_REF_EP\n    actor_rollout_ref.ref.megatron.context_parallel_size=\\$ACTOR_REF_CP\n    actor_rollout_ref.ref.megatron.sequence_parallel=\\$ACTOR_REF_SP\n    actor_rollout_ref.ref.megatron.param_offload=True\n    actor_rollout_ref.ref.megatron.use_dist_checkpointing=False\n    # Algorithm configuration\n    algorithm.weight_factor_in_cpgd='STD_weight'\n    algorithm.use_kl_in_reward=\\$USE_KL_IN_REWARD\n    algorithm.kl_ctrl.kl_coef=\\$KL_COEF\n    # Trainer configuration\n    trainer.critic_warmup=0\n    trainer.logger='[\"console\",\"tensorboard\"]'\n    trainer.project_name=\\$PROJECT_NAME\n    trainer.experiment_name=\\$EXPERIMENT_NAME\n    trainer.n_gpus_per_node=\\$N_GPUS_PER_NODE\n    trainer.nnodes=\\$NNODES\n    trainer.save_freq=\\$SAVE_FREQ\n    trainer.test_freq=\\$TEST_FREQ\n    trainer.total_epochs=\\$TOTAL_EPOCHS\n    trainer.resume_mode=off\n    trainer.max_actor_ckpt_to_keep=\\$MAX_CKPT_KEEP\n    trainer.default_local_dir=\\$CKPT_PATH\n    trainer.val_before_train=True\n    dag.enable_perf=True\n)\n\n# ===================================================================================\n# ===                  MAIN EXECUTION LOGIC & INFRASTRUCTURE                      ===\n# ===================================================================================\n\n# --- Boilerplate Setup ---\nset -e\nset -o pipefail\nset -x\n\n# --- Infrastructure & Boilerplate Functions ---\nstart_ray_cluster() {\n    local RAY_HEAD_WAIT_TIMEOUT=600\n    export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200\n    export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120\n\n    local ray_start_common_opts=(\n        --num-gpus \"$N_GPUS_PER_NODE\"\n        --object-store-memory 100000000000\n        --memory 100000000000\n    )\n\n    if [ \"$NNODES\" -gt 1 ]; then\n        if [ \"$NODE_RANK\" = \"0\" ]; then\n            echo \"INFO: Starting Ray head node on $(hostname)...\"\n            export RAY_ADDRESS=\"$RAY_MASTER_ADDR:$RAY_MASTER_PORT\"\n            ray start --head --port=\"$RAY_MASTER_PORT\" --dashboard-port=\"$RAY_DASHBOARD_PORT\" \"${ray_start_common_opts[@]}\" --system-config='{\"gcs_server_request_timeout_seconds\": 60, \"gcs_rpc_server_reconnect_timeout_s\": 60}'\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$RAY_ADDRESS\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head node. Exiting.\" >&2; ray stop --force; exit 1; fi\n                echo \"Head node not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head node is healthy.\"\n        else\n            local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n            echo \"INFO: Worker node $(hostname) waiting for head at $head_node_address...\"\n            local start_time=$(date +%s)\n            while ! 
ray health-check --address \"$head_node_address\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head. Exiting.\" >&2; exit 1; fi\n                echo \"Head not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head is healthy. Worker starting...\"\n            ray start --address=\"$head_node_address\" \"${ray_start_common_opts[@]}\"\n        fi\n    else\n        echo \"INFO: Starting Ray in single-node mode...\"\n        ray start --head \"${ray_start_common_opts[@]}\"\n    fi\n}\n\n# --- Main Execution Function ---\nmain() {\n    local timestamp=$(date +\"%Y%m%d_%H%M%S\")\n    ray stop --force\n\n    export VLLM_USE_V1=1\n    export GLOO_SOCKET_TIMEOUT=600\n    export GLOO_TCP_TIMEOUT=600\n    export GLOO_LOG_LEVEL=DEBUG\n    export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}\n    export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}\n    export RAY_MASTER_ADDR=$MASTER_ADDR\n\n    start_ray_cluster\n\n    if [ \"$NNODES\" -gt 1 ] && [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"Waiting for all $NNODES nodes to join...\"\n        local TIMEOUT=600; local start_time=$(date +%s)\n        while true; do\n            if [ \"$(( $(date +%s) - start_time ))\" -ge \"$TIMEOUT\" ]; then echo \"Error: Timeout waiting for nodes.\" >&2; exit 1; fi\n            local ready_nodes=$(ray list nodes --format=json | python3 -c \"import sys, json; print(len(json.load(sys.stdin)))\")\n            if [ \"$ready_nodes\" -ge \"$NNODES\" ]; then break; fi\n            echo \"Waiting... ($ready_nodes / $NNODES nodes ready)\"\n            sleep 5\n        done\n        echo \"All $NNODES nodes have joined.\"\n    fi\n\n    if [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"INFO [RANK 0]: Starting GSPO training command.\"\n        echo \"Command: ${TRAINING_CMD[*]}\"\n        eval \"${TRAINING_CMD[@]}\" \"$@\"\n        echo \"INFO [RANK 0]: Training finished.\"\n        sleep 30; ray stop --force >/dev/null 2>&1\n    elif [ \"$NNODES\" -gt 1 ]; then\n        local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n        echo \"INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address.\"\n        while ray health-check --address \"$head_node_address\" &>/dev/null; do sleep 15; done\n        echo \"INFO [RANK $NODE_RANK]: Head node down. Exiting.\"\n    fi\n\n    echo \"INFO: Script finished on rank $NODE_RANK.\"\n}\n\n# --- Script Entrypoint ---\nif [[ \"${BASH_SOURCE[0]}\" == \"${0}\" ]]; then\n    main \"$@\"\nfi\n"
  },
  {
    "path": "examples/multi_turn/config/interaction_config/gsm8k_interaction_config.yaml",
    "content": "interaction:\n  - name: \"gsm8k\"\n    class_name: \"siirl.execution.rollout_flow.multiturn.interactions.gsm8k_interaction.Gsm8kInteraction\"\n    config: {}"
  },
  {
    "path": "examples/multi_turn/config/tool_config/gsm8k_tool_config.yaml",
    "content": "tools:\n  - class_name: \"siirl.execution.rollout_flow.multiturn.tools.gsm8k_tool.Gsm8kTool\"\n    config: \n      type: native\n    tool_schema:\n      type: \"function\"\n      function:\n        name: \"calc_gsm8k_reward\"\n        description: \"A tool for calculating the reward of gsm8k. (1.0 if parsed answer is correct, 0.0 if parsed answer is incorrect or not correctly parsed)\"\n        parameters:\n          type: \"object\"\n          properties:\n            answer:\n              type: \"string\"\n              description: \"The model's answer to the GSM8K math problem, must be a digits\"\n          required: [\"answer\"]\n"
  },
  {
    "path": "examples/multi_turn/gsm8k/run_qwen2_5-3b_grpo_multiturn_sglang.sh",
    "content": "#!/usr/bin/env bash\n# ===================================================================================\n# ===                       USER CONFIGURATION SECTION                            ===\n# ===================================================================================\n\n# --- Experiment and Model Definition ---\nexport DATASET=deepscaler\nexport ALG=grpo\nexport MODEL_NAME=qwen2.5-3b\n\n# --- Path Definitions ---\nexport HOME={your_home_path}\nexport TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet\nexport TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet\nexport MODEL_PATH=$HOME/data/models/Qwen2.5-3B-Instruct\n\n\n# Base output paths\nexport BASE_CKPT_PATH=ckpts\nexport BASE_TENSORBOARD_PATH=tensorboard\n\n# --- Key Training Hyperparameters ---\nexport TRAIN_BATCH_SIZE_PER_NODE=256\nexport PPO_MINI_BATCH_SIZE_PER_NODE=256\nexport PPO_MICRO_BATCH_SIZE_PER_GPU=8\nexport MAX_PROMPT_LENGTH=1024\nexport MAX_RESPONSE_LENGTH=1024\nexport ROLLOUT_GPU_MEMORY_UTILIZATION=0.4\nexport ROLLOUT_TP=2\nexport ROLLOUT_N=8\nexport SAVE_FREQ=30\nexport TEST_FREQ=10\nexport TOTAL_EPOCHS=30\nexport MAX_CKPT_KEEP=5\nexport PROJECT_DIR=\"$(pwd)\"\nexport CONFIG_PATH=$PROJECT_DIR/examples/multi_turn/config\nexport TOOL_CONFIG_PATH=$PROJECT_DIR/examples/multi_turn/config/tool_config/gsm8k_tool_config.yaml\nexport INTERACTION_CONFIG_PATH=$PROJECT_DIR/examples/multi_turn/config/interaction_config/gsm8k_interaction_config.yaml\n\n# --- Distributed Training & Infrastructure ---\nexport N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}\nexport NNODES=${PET_NNODES:-1}\nexport NODE_RANK=${PET_NODE_RANK:-0}\nexport MASTER_ADDR=${MASTER_ADDR:-localhost}\n\n# --- Output Paths and Experiment Naming ---\nexport CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes\nexport PROJECT_NAME=siirl_${DATASET}_${ALG}\nexport EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment\nexport TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp\nexport SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp\n\n# --- Calculated Global Hyperparameters ---\nexport TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))\nexport PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))\n\n\n\n# --- Define the Training Command and its Arguments ---\nTRAINING_CMD=(\n    python3 -m siirl.main_dag\n    algorithm.adv_estimator=\\$ALG\n    data.train_files=\\$TRAIN_DATA_PATH\n    data.val_files=\\$TEST_DATA_PATH\n    data.train_batch_size=\\$TRAIN_BATCH_SIZE\n    data.max_prompt_length=\\$MAX_PROMPT_LENGTH\n    data.max_response_length=\\$MAX_RESPONSE_LENGTH\n    data.filter_overlong_prompts=True\n    data.return_raw_chat=True \n    data.truncation='error'\n    data.shuffle=False\n    actor_rollout_ref.model.path=\\$MODEL_PATH\n    actor_rollout_ref.actor.optim.lr=1e-6\n    actor_rollout_ref.model.use_remove_padding=True\n    actor_rollout_ref.model.use_fused_kernels=False\n    actor_rollout_ref.actor.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.actor.use_kl_loss=True\n    actor_rollout_ref.actor.grad_clip=0.5\n    actor_rollout_ref.actor.clip_ratio=0.2\n    actor_rollout_ref.actor.kl_loss_coef=0.01\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl\n    actor_rollout_ref.model.enable_gradient_checkpointing=True\n    
actor_rollout_ref.actor.fsdp_config.param_offload=False\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.rollout.tensor_model_parallel_size=\\$ROLLOUT_TP\n    actor_rollout_ref.rollout.name=sglang\n    actor_rollout_ref.rollout.mode=sync\n    actor_rollout_ref.rollout.gpu_memory_utilization=\\$ROLLOUT_GPU_MEMORY_UTILIZATION\n    actor_rollout_ref.rollout.max_model_len=8192\n    actor_rollout_ref.rollout.enable_chunked_prefill=False\n    actor_rollout_ref.rollout.enforce_eager=False\n    actor_rollout_ref.rollout.free_cache_engine=False\n    actor_rollout_ref.rollout.n=\\$ROLLOUT_N\n    actor_rollout_ref.rollout.prompt_length=\\$MAX_PROMPT_LENGTH  \n    actor_rollout_ref.rollout.response_length=\\$MAX_RESPONSE_LENGTH\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.ref.fsdp_config.param_offload=True\n    algorithm.kl_ctrl.kl_coef=0.001\n    trainer.critic_warmup=0\n    trainer.logger=['console','tensorboard']\n    trainer.project_name=\\$PROJECT_NAME\n    trainer.experiment_name=\\$EXPERIMENT_NAME\n    trainer.n_gpus_per_node=\\$N_GPUS_PER_NODE\n    trainer.nnodes=\\$NNODES\n    trainer.save_freq=\\$SAVE_FREQ\n    trainer.test_freq=\\$TEST_FREQ\n    trainer.total_epochs=\\$TOTAL_EPOCHS\n    trainer.resume_mode=auto\n    trainer.max_actor_ckpt_to_keep=\\$MAX_CKPT_KEEP\n    trainer.default_local_dir=\\$CKPT_PATH\n    trainer.val_before_train=True\n    actor_rollout_ref.rollout.multi_turn.enable=True\n    actor_rollout_ref.rollout.multi_turn.max_assistant_turns=5\n    actor_rollout_ref.rollout.multi_turn.tool_config_path=\"\\$TOOL_CONFIG_PATH\" \n    actor_rollout_ref.rollout.multi_turn.interaction_config_path=\"\\$INTERACTION_CONFIG_PATH\" \n)\n\n# ===================================================================================\n# ===                  MAIN EXECUTION LOGIC & INFRASTRUCTURE                      ===\n# ===================================================================================\n\n# --- Boilerplate Setup ---\nset -e\nset -o pipefail\nset -x\n\n# --- Infrastructure & Boilerplate Functions ---\nstart_ray_cluster() {\n    local RAY_HEAD_WAIT_TIMEOUT=600\n    export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200\n    export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120\n\n    local ray_start_common_opts=(\n        --num-gpus \"$N_GPUS_PER_NODE\"\n        --object-store-memory 100000000000\n        --memory 100000000000\n    )\n\n    if [ \"$NNODES\" -gt 1 ]; then\n        if [ \"$NODE_RANK\" = \"0\" ]; then\n            echo \"INFO: Starting Ray head node on $(hostname)...\"\n            export RAY_ADDRESS=\"$RAY_MASTER_ADDR:$RAY_MASTER_PORT\"\n            ray start --head --port=\"$RAY_MASTER_PORT\" --dashboard-port=\"$RAY_DASHBOARD_PORT\" \"${ray_start_common_opts[@]}\" --system-config='{\"gcs_server_request_timeout_seconds\": 60, \"gcs_rpc_server_reconnect_timeout_s\": 60}'\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$RAY_ADDRESS\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head node. Exiting.\" >&2; ray stop --force; exit 1; fi\n                echo \"Head node not healthy yet. 
Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head node is healthy.\"\n        else\n            local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n            echo \"INFO: Worker node $(hostname) waiting for head at $head_node_address...\"\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$head_node_address\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head. Exiting.\" >&2; exit 1; fi\n                echo \"Head not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head is healthy. Worker starting...\"\n            ray start --address=\"$head_node_address\" \"${ray_start_common_opts[@]}\"\n        fi\n    else\n        echo \"INFO: Starting Ray in single-node mode...\"\n        ray start --head \"${ray_start_common_opts[@]}\"\n    fi\n}\n\n# --- Main Execution Function ---\nmain() {\n    local timestamp=$(date +\"%Y%m%d_%H%M%S\")\n    ray stop --force\n\n    if [ \"$HOME\" = \"{your_home_path}\" ] || [ -z \"$HOME\" ]; then echo \"ERROR: Please set 'HOME' variable.\" >&2; exit 1; fi\n    \n    export VLLM_USE_V1=1\n    export GLOO_SOCKET_TIMEOUT=600\n    export GLOO_TCP_TIMEOUT=600\n    export GLOO_LOG_LEVEL=DEBUG\n    export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}\n    export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}\n    export RAY_MASTER_ADDR=$MASTER_ADDR\n    \n    start_ray_cluster\n\n    if [ \"$NNODES\" -gt 1 ] && [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"Waiting for all $NNODES nodes to join...\"\n        local TIMEOUT=600; local start_time=$(date +%s)\n        while true; do\n            if [ \"$(( $(date +%s) - start_time ))\" -ge \"$TIMEOUT\" ]; then echo \"Error: Timeout waiting for nodes.\" >&2; exit 1; fi\n            local ready_nodes=$(ray list nodes --format=json | python3 -c \"import sys, json; print(len(json.load(sys.stdin)))\")\n            if [ \"$ready_nodes\" -ge \"$NNODES\" ]; then break; fi\n            echo \"Waiting... ($ready_nodes / $NNODES nodes ready)\"\n            sleep 5\n        done\n        echo \"All $NNODES nodes have joined.\"\n    fi\n\n    if [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"INFO [RANK 0]: Starting main training command.\"\n        eval \"${TRAINING_CMD[@]}\" \"$@\"\n        echo \"INFO [RANK 0]: Training finished.\"\n        sleep 30; ray stop --force >/dev/null 2>&1\n    elif [ \"$NNODES\" -gt 1 ]; then\n        local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n        echo \"INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address.\"\n        while ray health-check --address \"$head_node_address\" &>/dev/null; do sleep 15; done\n        echo \"INFO [RANK $NODE_RANK]: Head node down. Exiting.\"\n    fi\n\n    echo \"INFO: Script finished on rank $NODE_RANK.\"\n}\n\n# --- Script Entrypoint ---\nmain \"$@\"\n"
  },
  {
    "path": "examples/ppo_trainer/run_qwen2_5-72b.sh",
    "content": "#!/usr/bin/env bash\n# ===================================================================================\n# ===                       USER CONFIGURATION SECTION                            ===\n# ===================================================================================\n\n# --- Experiment and Model Definition ---\nexport DATASET=deepscaler\nexport ALG=gae\nexport MODEL_NAME=qwen2.5-72b\n\n# --- Path Definitions ---\nexport HOME={your_home_path}\nexport TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet\nexport TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet\nexport MODEL_PATH=$HOME/data/models/Qwen2.5-72B-Instruct\n\n# Base output paths\nexport BASE_CKPT_PATH=ckpts\nexport BASE_TENSORBOARD_PATH=tensorboard\n\n# --- Key Training Hyperparameters ---\nexport TRAIN_BATCH_SIZE_PER_NODE=512\nexport PPO_MINI_BATCH_SIZE_PER_NODE=128\nexport PPO_MICRO_BATCH_SIZE_PER_GPU=8\nexport MAX_PROMPT_LENGTH=2048\nexport MAX_RESPONSE_LENGTH=4096\nexport ROLLOUT_GPU_MEMORY_UTILIZATION=0.6\nexport ROLLOUT_TP=8\nexport ROLLOUT_N=1\nexport SAVE_FREQ=30\nexport TEST_FREQ=10\nexport TOTAL_EPOCHS=30\nexport MAX_CKPT_KEEP=5\n\n# --- Multi-node (Multi-machine) distributed training environments ---\n\n# Uncomment the following line and set the correct network interface if needed for distributed backend\n# export GLOO_SOCKET_IFNAME=bond0  # Modify as needed\n\n# --- Distributed Training & Infrastructure ---\nexport N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}\nexport NNODES=${PET_NNODES:-1}\nexport NODE_RANK=${PET_NODE_RANK:-0}\nexport MASTER_ADDR=${MASTER_ADDR:-localhost}\n\n# --- Output Paths and Experiment Naming ---\nexport CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes\nexport PROJECT_NAME=siirl_${DATASET}_${ALG}\nexport EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment\nexport TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp\nexport SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp\n\n# --- Calculated Global Hyperparameters ---\nexport TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))\nexport PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))\n\n# --- Define the Training Command and its Arguments ---\nTRAINING_CMD=(\n    python3 -m siirl.main_dag\n    algorithm.adv_estimator=\\$ALG\n    data.train_files=\\$TRAIN_DATA_PATH\n    data.val_files=\\$TEST_DATA_PATH\n    data.train_batch_size=\\$TRAIN_BATCH_SIZE\n    data.max_prompt_length=\\$MAX_PROMPT_LENGTH\n    data.max_response_length=\\$MAX_RESPONSE_LENGTH\n    data.filter_overlong_prompts=True\n    data.truncation='error'\n    data.shuffle=False\n    actor_rollout_ref.model.path=\\$MODEL_PATH\n    actor_rollout_ref.actor.optim.lr=1e-6\n    actor_rollout_ref.model.use_remove_padding=True\n    actor_rollout_ref.model.use_fused_kernels=False\n    actor_rollout_ref.actor.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.actor.use_kl_loss=True\n    actor_rollout_ref.actor.grad_clip=0.5\n    actor_rollout_ref.actor.clip_ratio=0.2\n    actor_rollout_ref.actor.kl_loss_coef=0.01\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl\n    actor_rollout_ref.model.enable_gradient_checkpointing=True\n    actor_rollout_ref.actor.fsdp_config.param_offload=False\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False\n    
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1\n    actor_rollout_ref.rollout.tensor_model_parallel_size=\\$ROLLOUT_TP\n    actor_rollout_ref.rollout.name=vllm\n    actor_rollout_ref.rollout.gpu_memory_utilization=\\$ROLLOUT_GPU_MEMORY_UTILIZATION\n    actor_rollout_ref.rollout.max_model_len=8192\n    actor_rollout_ref.rollout.enable_chunked_prefill=False\n    actor_rollout_ref.rollout.enforce_eager=False\n    actor_rollout_ref.rollout.free_cache_engine=False\n    actor_rollout_ref.rollout.n=\\$ROLLOUT_N\n    actor_rollout_ref.rollout.prompt_length=\\$MAX_PROMPT_LENGTH  \n    actor_rollout_ref.rollout.response_length=\\$MAX_RESPONSE_LENGTH\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1\n    actor_rollout_ref.ref.fsdp_config.param_offload=True\n    critic.optim.lr=1e-5\n    critic.model.use_remove_padding=True\n    critic.model.path=\\$MODEL_PATH\n    critic.model.enable_gradient_checkpointing=True\n    critic.use_dynamic_bsz=False\n    critic.ppo_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    critic.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n    critic.ppo_max_token_len_per_gpu=98304\n    critic.model.fsdp_config.param_offload=False\n    critic.model.fsdp_config.optimizer_offload=False\n    algorithm.kl_ctrl.kl_coef=0.001\n    algorithm.use_kl_in_reward=False\n    trainer.critic_warmup=0\n    trainer.logger=['console','tensorboard']\n    trainer.project_name=\\$PROJECT_NAME\n    trainer.experiment_name=\\$EXPERIMENT_NAME\n    trainer.n_gpus_per_node=\\$N_GPUS_PER_NODE\n    trainer.nnodes=\\$NNODES\n    trainer.save_freq=\\$SAVE_FREQ\n    trainer.test_freq=\\$TEST_FREQ\n    trainer.total_epochs=\\$TOTAL_EPOCHS\n    trainer.resume_mode=auto\n    trainer.max_actor_ckpt_to_keep=\\$MAX_CKPT_KEEP\n    trainer.del_local_ckpt_after_load=False\n    trainer.default_local_dir=\\$CKPT_PATH\n    trainer.val_before_train=True\n)\n\n# ===================================================================================\n# ===                  MAIN EXECUTION LOGIC & INFRASTRUCTURE                      ===\n# ===================================================================================\n\n# --- Boilerplate Setup ---\nset -e\nset -o pipefail\nset -x\n\n# --- Infrastructure & Boilerplate Functions ---\nstart_ray_cluster() {\n    local RAY_HEAD_WAIT_TIMEOUT=600\n    export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200\n    export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120\n\n    local ray_start_common_opts=(\n        --num-gpus \"$N_GPUS_PER_NODE\"\n        --object-store-memory 100000000000\n        --memory 100000000000\n    )\n\n    if [ \"$NNODES\" -gt 1 ]; then\n        if [ \"$NODE_RANK\" = \"0\" ]; then\n            echo \"INFO: Starting Ray head node on $(hostname)...\"\n            export RAY_ADDRESS=\"$RAY_MASTER_ADDR:$RAY_MASTER_PORT\"\n            ray start --head --port=\"$RAY_MASTER_PORT\" --dashboard-port=\"$RAY_DASHBOARD_PORT\" \"${ray_start_common_opts[@]}\" --system-config='{\"gcs_server_request_timeout_seconds\": 60, \"gcs_rpc_server_reconnect_timeout_s\": 60}'\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$RAY_ADDRESS\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head node. 
Exiting.\" >&2; ray stop --force; exit 1; fi\n                echo \"Head node not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head node is healthy.\"\n        else\n            local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n            echo \"INFO: Worker node $(hostname) waiting for head at $head_node_address...\"\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$head_node_address\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head. Exiting.\" >&2; exit 1; fi\n                echo \"Head not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head is healthy. Worker starting...\"\n            ray start --address=\"$head_node_address\" \"${ray_start_common_opts[@]}\" --block\n        fi\n    else\n        echo \"INFO: Starting Ray in single-node mode...\"\n        ray start --head \"${ray_start_common_opts[@]}\"\n    fi\n}\n\n# --- Main Execution Function ---\nmain() {\n    local timestamp=$(date +\"%Y%m%d_%H%M%S\")\n    ray stop --force\n\n    \n\n    export VLLM_USE_V1=1\n    export GLOO_SOCKET_IFNAME=bond0\n    export GLOO_SOCKET_TIMEOUT=600\n    export GLOO_TCP_TIMEOUT=600\n    export GLOO_LOG_LEVEL=DEBUG\n    export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}\n    export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}\n    export RAY_MASTER_ADDR=$MASTER_ADDR\n\n    start_ray_cluster\n\n    if [ \"$NNODES\" -gt 1 ] && [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"Waiting for all $NNODES nodes to join...\"\n        local TIMEOUT=600; local start_time=$(date +%s)\n        while true; do\n            if [ \"$(( $(date +%s) - start_time ))\" -ge \"$TIMEOUT\" ]; then echo \"Error: Timeout waiting for nodes.\" >&2; exit 1; fi\n            local ready_nodes=$(ray list nodes --format=json | python3 -c \"import sys, json; print(len(json.load(sys.stdin)))\")\n            if [ \"$ready_nodes\" -ge \"$NNODES\" ]; then break; fi\n            echo \"Waiting... ($ready_nodes / $NNODES nodes ready)\"\n            sleep 2\n        done\n        echo \"All $NNODES nodes have joined.\"\n    fi\n\n    if [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"INFO [RANK 0]: Starting main training command.\"\n        eval \"${TRAINING_CMD[@]}\" \"$@\"\n        echo \"INFO [RANK 0]: Training finished.\"\n        sleep 30; ray stop --force >/dev/null 2>&1\n    elif [ \"$NNODES\" -gt 1 ]; then\n        local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n        echo \"INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address.\"\n        while ray health-check --address \"$head_node_address\" &>/dev/null; do sleep 15; done\n        echo \"INFO [RANK $NODE_RANK]: Head node down. Exiting.\"\n    fi\n\n    echo \"INFO: Script finished on rank $NODE_RANK.\"\n}\n\n# --- Script Entrypoint ---\nmain \"$@\"\n"
  },
  {
    "path": "examples/ppo_trainer/run_qwen3-8b-megatron.sh",
    "content": "#!/usr/bin/env bash\n# ===================================================================================\n# ===                       USER CONFIGURATION SECTION                            ===\n# ===================================================================================\n\n# --- For debugging\nexport HYDRA_FULL_ERROR=0\nexport SIIRL_LOG_VERBOSITY=INFO\n\n# --- Experiment and Model Definition ---\nexport DATASET=deepscaler\nexport ALG=gae\nexport MODEL_NAME=qwen3-8b\n\n# --- Path Definitions ---\nexport TRAIN_DATA_PATH=$HOME/data/dataset/$DATASET/train.parquet\nexport TEST_DATA_PATH=$HOME/data/dataset/$DATASET/test.parquet\nexport MODEL_PATH=$HOME/data/models/Qwen3-8B\n\n# Base output paths\nexport BASE_CKPT_PATH=$HOME/ckpts\nexport BASE_TENSORBOARD_PATH=$HOME/tensorboard\n\n# --- Key Training Hyperparameters ---\nexport TRAIN_BATCH_SIZE_PER_NODE=1024\nexport PPO_MINI_BATCH_SIZE_PER_NODE=256\nexport PPO_MICRO_BATCH_SIZE_PER_GPU=8\nexport MAX_PROMPT_LENGTH=2048\nexport MAX_RESPONSE_LENGTH=4096\nexport ROLLOUT_GPU_MEMORY_UTILIZATION=0.45\nexport ROLLOUT_N=1\nexport SAVE_FREQ=30\nexport TEST_FREQ=10\nexport TOTAL_EPOCHS=30\nexport MAX_CKPT_KEEP=5\n\nexport ACTOR_REF_CRITIC_TP=2\nexport ACTOR_REF_CRITIC_PP=2\nexport ACTOR_REF_CRITIC_CP=1\nexport ACTOR_REF_CRITIC_SP=False\n\n# --- Multi-node (Multi-machine) distributed training environments ---\n\n# Uncomment the following line and set the correct network interface if needed for distributed backend\n# export GLOO_SOCKET_IFNAME=bond0  # Modify as needed\n\n# --- Distributed Training & Infrastructure ---\nexport N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}\nexport NNODES=${PET_NNODES:-1}\nexport NODE_RANK=${PET_NODE_RANK:-0}\nexport MASTER_ADDR=${MASTER_ADDR:-localhost}\n\n# --- Output Paths and Experiment Naming ---\nexport CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes\nexport PROJECT_NAME=siirl_${DATASET}_${ALG}\nexport EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment\nexport TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp\nexport SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp\n\n# --- Calculated Global Hyperparameters ---\nexport TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))\nexport PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))\n\n# --- Define the Training Command and its Arguments ---\nTRAINING_CMD=(\n    python3 -m siirl.main_dag\n    algorithm.adv_estimator=\\$ALG\n    data.train_files=\\$TRAIN_DATA_PATH\n    data.val_files=\\$TEST_DATA_PATH\n    data.train_batch_size=\\$TRAIN_BATCH_SIZE\n    data.max_prompt_length=\\$MAX_PROMPT_LENGTH\n    data.max_response_length=\\$MAX_RESPONSE_LENGTH\n    data.filter_overlong_prompts=True\n    data.truncation='error'\n    data.shuffle=False\n    actor_rollout_ref.model.path=\\$MODEL_PATH\n    actor_rollout_ref.model.use_remove_padding=True\n    actor_rollout_ref.model.use_fused_kernels=False\n    actor_rollout_ref.model.enable_gradient_checkpointing=True\n    actor_rollout_ref.actor.strategy=megatron\n    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=\\$ACTOR_REF_CRITIC_TP\n    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=\\$ACTOR_REF_CRITIC_PP\n    actor_rollout_ref.actor.megatron.context_parallel_size=\\$ACTOR_REF_CRITIC_CP\n    actor_rollout_ref.actor.megatron.sequence_parallel=\\$ACTOR_REF_CRITIC_SP\n    
actor_rollout_ref.actor.megatron.use_distributed_optimizer=True\n    actor_rollout_ref.actor.megatron.param_dtype=bfloat16\n    actor_rollout_ref.actor.megatron.param_offload=True\n    actor_rollout_ref.actor.megatron.use_dist_checkpointing=False\n    actor_rollout_ref.actor.megatron.seed=1\n    actor_rollout_ref.actor.optim.lr=1e-6\n    actor_rollout_ref.actor.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.actor.use_kl_loss=True\n    actor_rollout_ref.actor.grad_clip=0.5\n    actor_rollout_ref.actor.clip_ratio=0.2\n    actor_rollout_ref.actor.kl_loss_coef=0.01\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl\n    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.rollout.tensor_model_parallel_size=\\$ACTOR_REF_CRITIC_TP\n    actor_rollout_ref.rollout.name=vllm\n    actor_rollout_ref.rollout.gpu_memory_utilization=\\$ROLLOUT_GPU_MEMORY_UTILIZATION\n    actor_rollout_ref.rollout.max_model_len=8192\n    actor_rollout_ref.rollout.enable_chunked_prefill=False\n    actor_rollout_ref.rollout.enforce_eager=True\n    actor_rollout_ref.rollout.free_cache_engine=True\n    actor_rollout_ref.rollout.n=\\$ROLLOUT_N\n    actor_rollout_ref.rollout.prompt_length=\\$MAX_PROMPT_LENGTH  \n    actor_rollout_ref.rollout.response_length=\\$MAX_RESPONSE_LENGTH\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.ref.strategy=megatron\n    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=\\$ACTOR_REF_CRITIC_TP\n    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=\\$ACTOR_REF_CRITIC_PP\n    actor_rollout_ref.ref.megatron.context_parallel_size=\\$ACTOR_REF_CRITIC_CP\n    actor_rollout_ref.ref.megatron.sequence_parallel=\\$ACTOR_REF_CRITIC_SP\n    actor_rollout_ref.ref.megatron.param_offload=False\n    critic.optim.lr=1e-5\n    critic.model.use_remove_padding=True\n    critic.model.path=\\$MODEL_PATH\n    critic.model.enable_gradient_checkpointing=True\n    critic.use_dynamic_bsz=False\n    critic.ppo_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    critic.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n    critic.ppo_max_token_len_per_gpu=98304\n    critic.strategy=megatron\n    critic.megatron.tensor_model_parallel_size=\\$ACTOR_REF_CRITIC_TP\n    critic.megatron.pipeline_model_parallel_size=\\$ACTOR_REF_CRITIC_PP\n    critic.megatron.context_parallel_size=\\$ACTOR_REF_CRITIC_CP\n    critic.megatron.sequence_parallel=\\$ACTOR_REF_CRITIC_SP\n    critic.megatron.param_offload=True\n    critic.megatron.optimizer_offload=True\n    algorithm.kl_ctrl.kl_coef=0.001\n    trainer.critic_warmup=0\n    trainer.logger=['console','tensorboard']\n    trainer.project_name=\\$PROJECT_NAME\n    trainer.experiment_name=\\$EXPERIMENT_NAME\n    trainer.n_gpus_per_node=\\$N_GPUS_PER_NODE\n    trainer.nnodes=\\$NNODES\n    trainer.save_freq=\\$SAVE_FREQ\n    trainer.test_freq=\\$TEST_FREQ\n    trainer.total_epochs=\\$TOTAL_EPOCHS\n    trainer.resume_mode=auto\n    trainer.max_actor_ckpt_to_keep=\\$MAX_CKPT_KEEP\n    trainer.default_local_dir=\\$CKPT_PATH\n    trainer.val_before_train=True\n)\n\n# ===================================================================================\n# ===                  MAIN EXECUTION LOGIC & INFRASTRUCTURE                      ===\n# ===================================================================================\n\n# --- Boilerplate 
Setup ---\nset -e\nset -o pipefail\nset -x\n\n# --- Infrastructure & Boilerplate Functions ---\nstart_ray_cluster() {\n    local RAY_HEAD_WAIT_TIMEOUT=600\n    export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200\n    export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120\n\n    local ray_start_common_opts=(\n        --num-gpus \"$N_GPUS_PER_NODE\"\n        --object-store-memory 100000000000\n        --memory 100000000000\n    )\n\n    if [ \"$NNODES\" -gt 1 ]; then\n        if [ \"$NODE_RANK\" = \"0\" ]; then\n            echo \"INFO: Starting Ray head node on $(hostname)...\"\n            export RAY_ADDRESS=\"$RAY_MASTER_ADDR:$RAY_MASTER_PORT\"\n            ray start --head --port=\"$RAY_MASTER_PORT\" --dashboard-port=\"$RAY_DASHBOARD_PORT\" \"${ray_start_common_opts[@]}\" --system-config='{\"gcs_server_request_timeout_seconds\": 60, \"gcs_rpc_server_reconnect_timeout_s\": 60}'\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$RAY_ADDRESS\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head node. Exiting.\" >&2; ray stop --force; exit 1; fi\n                echo \"Head node not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head node is healthy.\"\n        else\n            local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n            echo \"INFO: Worker node $(hostname) waiting for head at $head_node_address...\"\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$head_node_address\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head. Exiting.\" >&2; exit 1; fi\n                echo \"Head not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head is healthy. Worker starting...\"\n            ray start --address=\"$head_node_address\" \"${ray_start_common_opts[@]}\"\n        fi\n    else\n        echo \"INFO: Starting Ray in single-node mode...\"\n        ray start --head \"${ray_start_common_opts[@]}\"\n    fi\n}\n\n# --- Main Execution Function ---\nmain() {\n    local timestamp=$(date +\"%Y%m%d_%H%M%S\")\n    ray stop --force\n\n\n\n    export VLLM_USE_V1=1\n    export GLOO_SOCKET_TIMEOUT=600\n    export GLOO_TCP_TIMEOUT=600\n    export GLOO_LOG_LEVEL=DEBUG\n    export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}\n    export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}\n    export RAY_MASTER_ADDR=$MASTER_ADDR\n    \n    start_ray_cluster\n\n    if [ \"$NNODES\" -gt 1 ] && [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"Waiting for all $NNODES nodes to join...\"\n        local TIMEOUT=600; local start_time=$(date +%s)\n        while true; do\n            if [ \"$(( $(date +%s) - start_time ))\" -ge \"$TIMEOUT\" ]; then echo \"Error: Timeout waiting for nodes.\" >&2; exit 1; fi\n            local ready_nodes=$(ray list nodes --format=json | python3 -c \"import sys, json; print(len(json.load(sys.stdin)))\")\n            if [ \"$ready_nodes\" -ge \"$NNODES\" ]; then break; fi\n            echo \"Waiting... 
($ready_nodes / $NNODES nodes ready)\"\n            sleep 5\n        done\n        echo \"All $NNODES nodes have joined.\"\n    fi\n\n    if [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"INFO [RANK 0]: Starting main training command.\"\n        eval \"${TRAINING_CMD[@]}\" \"$@\"\n        echo \"INFO [RANK 0]: Training finished.\"\n        sleep 30; ray stop --force >/dev/null 2>&1\n    elif [ \"$NNODES\" -gt 1 ]; then\n        local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n        echo \"INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address.\"\n        while ray health-check --address \"$head_node_address\" &>/dev/null; do sleep 15; done\n        echo \"INFO [RANK $NODE_RANK]: Head node down. Exiting.\"\n    fi\n\n    echo \"INFO: Script finished on rank $NODE_RANK.\"\n}\n\n# --- Script Entrypoint ---\nmain \"$@\""
  },
  {
    "path": "examples/ppo_trainer/run_qwen3-8b.sh",
    "content": "#!/usr/bin/env bash\n# ===================================================================================\n# ===                       USER CONFIGURATION SECTION                            ===\n# ===================================================================================\n\n# --- Experiment and Model Definition ---\nexport DATASET=deepscaler\nexport ALG=gae\nexport MODEL_NAME=qwen3-8b\n\n# --- Path Definitions ---\nexport HOME={your_home_path}\nexport TRAIN_DATA_PATH=$HOME/data/datasets/$DATASET/train.parquet\nexport TEST_DATA_PATH=$HOME/data/datasets/$DATASET/test.parquet\nexport MODEL_PATH=$HOME/data/models/Qwen3-8B\n\n# Base output paths\nexport BASE_CKPT_PATH=ckpts\nexport BASE_TENSORBOARD_PATH=tensorboard\n\n# --- Key Training Hyperparameters ---\nexport TRAIN_BATCH_SIZE_PER_NODE=512\nexport PPO_MINI_BATCH_SIZE_PER_NODE=256\nexport PPO_MICRO_BATCH_SIZE_PER_GPU=8\nexport MAX_PROMPT_LENGTH=2048\nexport MAX_RESPONSE_LENGTH=4096\nexport ROLLOUT_GPU_MEMORY_UTILIZATION=0.6\nexport ROLLOUT_TP=1\nexport ROLLOUT_N=1\nexport SAVE_FREQ=30\nexport TEST_FREQ=10\nexport TOTAL_EPOCHS=30\nexport MAX_CKPT_KEEP=5\n\n# --- Multi-node (Multi-machine) distributed training environments ---\n\n# Uncomment the following line and set the correct network interface if needed for distributed backend\n# export GLOO_SOCKET_IFNAME=bond0  # Modify as needed\n\n# --- Distributed Training & Infrastructure ---\nexport N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}\nexport NNODES=${PET_NNODES:-1}\nexport NODE_RANK=${PET_NODE_RANK:-0}\nexport MASTER_ADDR=${MASTER_ADDR:-localhost}\n\n# --- Output Paths and Experiment Naming ---\nexport CKPT_PATH=${BASE_CKPT_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}nodes\nexport PROJECT_NAME=siirl_${DATASET}_${ALG}\nexport EXPERIMENT_NAME=siirl_${MODEL_NAME}_${ALG}_${DATASET}_experiment\nexport TENSORBOARD_DIR=${BASE_TENSORBOARD_PATH}/${MODEL_NAME}_${ALG}_${DATASET}_hybrid_tensorboard/dlc_${NNODES}_$timestamp\nexport SIIRL_LOGGING_FILENAME=${MODEL_NAME}_${ALG}_${DATASET}_hybrid_${NNODES}_$timestamp\n\n# --- Calculated Global Hyperparameters ---\nexport TRAIN_BATCH_SIZE=$(($TRAIN_BATCH_SIZE_PER_NODE * $NNODES))\nexport PPO_MINI_BATCH_SIZE=$(($PPO_MINI_BATCH_SIZE_PER_NODE * $NNODES))\n\n# --- Define the Training Command and its Arguments ---\nTRAINING_CMD=(\n    python3 -m siirl.main_dag\n    algorithm.adv_estimator=\\$ALG\n    data.train_files=\\$TRAIN_DATA_PATH\n    data.val_files=\\$TEST_DATA_PATH\n    data.train_batch_size=\\$TRAIN_BATCH_SIZE\n    data.max_prompt_length=\\$MAX_PROMPT_LENGTH\n    data.max_response_length=\\$MAX_RESPONSE_LENGTH\n    data.filter_overlong_prompts=True\n    data.truncation='error'\n    data.shuffle=False\n    actor_rollout_ref.model.path=\\$MODEL_PATH\n    actor_rollout_ref.actor.optim.lr=1e-6\n    actor_rollout_ref.model.use_remove_padding=True\n    actor_rollout_ref.model.use_fused_kernels=False\n    actor_rollout_ref.actor.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.actor.use_kl_loss=True\n    actor_rollout_ref.actor.grad_clip=0.5\n    actor_rollout_ref.actor.clip_ratio=0.2\n    actor_rollout_ref.actor.kl_loss_coef=0.01\n    actor_rollout_ref.actor.kl_loss_type=low_var_kl\n    actor_rollout_ref.model.enable_gradient_checkpointing=True\n    actor_rollout_ref.actor.fsdp_config.param_offload=False\n    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False\n    
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.rollout.tensor_model_parallel_size=\\$ROLLOUT_TP\n    actor_rollout_ref.rollout.name=vllm\n    actor_rollout_ref.rollout.gpu_memory_utilization=\\$ROLLOUT_GPU_MEMORY_UTILIZATION\n    actor_rollout_ref.rollout.max_model_len=8192\n    actor_rollout_ref.rollout.enable_chunked_prefill=False\n    actor_rollout_ref.rollout.enforce_eager=False\n    actor_rollout_ref.rollout.free_cache_engine=False\n    actor_rollout_ref.rollout.n=\\$ROLLOUT_N\n    actor_rollout_ref.rollout.prompt_length=\\$MAX_PROMPT_LENGTH  \n    actor_rollout_ref.rollout.response_length=\\$MAX_RESPONSE_LENGTH\n    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    actor_rollout_ref.ref.fsdp_config.param_offload=True\n    critic.optim.lr=1e-5\n    critic.model.use_remove_padding=True\n    critic.model.path=\\$MODEL_PATH\n    critic.model.enable_gradient_checkpointing=True\n    critic.use_dynamic_bsz=False\n    critic.ppo_micro_batch_size_per_gpu=\\$PPO_MICRO_BATCH_SIZE_PER_GPU\n    critic.ppo_mini_batch_size=\\$PPO_MINI_BATCH_SIZE\n    critic.ppo_max_token_len_per_gpu=98304\n    critic.model.fsdp_config.param_offload=False\n    critic.model.fsdp_config.optimizer_offload=False\n    algorithm.kl_ctrl.kl_coef=0.001\n    trainer.critic_warmup=0\n    trainer.logger=['console','tensorboard']\n    trainer.project_name=\\$PROJECT_NAME\n    trainer.experiment_name=\\$EXPERIMENT_NAME\n    trainer.n_gpus_per_node=\\$N_GPUS_PER_NODE\n    trainer.nnodes=\\$NNODES\n    trainer.save_freq=\\$SAVE_FREQ\n    trainer.test_freq=\\$TEST_FREQ\n    trainer.total_epochs=\\$TOTAL_EPOCHS\n    trainer.resume_mode=auto\n    trainer.max_actor_ckpt_to_keep=\\$MAX_CKPT_KEEP\n    trainer.default_local_dir=\\$CKPT_PATH\n    trainer.val_before_train=True\n)\n\n# ===================================================================================\n# ===                  MAIN EXECUTION LOGIC & INFRASTRUCTURE                      ===\n# ===================================================================================\n\n# --- Boilerplate Setup ---\nset -e\nset -o pipefail\nset -x\n\n# --- Infrastructure & Boilerplate Functions ---\nstart_ray_cluster() {\n    local RAY_HEAD_WAIT_TIMEOUT=600\n    export RAY_RAYLET_NODE_MANAGER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_GCS_SERVER_CONFIG_NIC_NAME=${INTERFACE_NAME}\n    export RAY_RUNTIME_ENV_AGENT_CREATION_TIMEOUT_S=1200\n    export RAY_GCS_RPC_CLIENT_CONNECT_TIMEOUT_S=120\n\n    local ray_start_common_opts=(\n        --num-gpus \"$N_GPUS_PER_NODE\"\n        --object-store-memory 100000000000\n        --memory 100000000000\n    )\n\n    if [ \"$NNODES\" -gt 1 ]; then\n        if [ \"$NODE_RANK\" = \"0\" ]; then\n            echo \"INFO: Starting Ray head node on $(hostname)...\"\n            export RAY_ADDRESS=\"$RAY_MASTER_ADDR:$RAY_MASTER_PORT\"\n            ray start --head --port=\"$RAY_MASTER_PORT\" --dashboard-port=\"$RAY_DASHBOARD_PORT\" \"${ray_start_common_opts[@]}\" --system-config='{\"gcs_server_request_timeout_seconds\": 60, \"gcs_rpc_server_reconnect_timeout_s\": 60}'\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$RAY_ADDRESS\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head node. Exiting.\" >&2; ray stop --force; exit 1; fi\n                echo \"Head node not healthy yet. 
Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head node is healthy.\"\n        else\n            local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n            echo \"INFO: Worker node $(hostname) waiting for head at $head_node_address...\"\n            local start_time=$(date +%s)\n            while ! ray health-check --address \"$head_node_address\" &>/dev/null; do\n                if [ \"$(( $(date +%s) - start_time ))\" -ge \"$RAY_HEAD_WAIT_TIMEOUT\" ]; then echo \"ERROR: Timed out waiting for head. Exiting.\" >&2; exit 1; fi\n                echo \"Head not healthy yet. Retrying in 5s...\"\n                sleep 5\n            done\n            echo \"INFO: Head is healthy. Worker starting...\"\n            ray start --address=\"$head_node_address\" \"${ray_start_common_opts[@]}\" --block\n        fi\n    else\n        echo \"INFO: Starting Ray in single-node mode...\"\n        ray start --head \"${ray_start_common_opts[@]}\"\n    fi\n}\n\n# --- Main Execution Function ---\nmain() {\n    local timestamp=$(date +\"%Y%m%d_%H%M%S\")\n    ray stop --force\n\n    \n\n    export VLLM_USE_V1=1\n    export GLOO_SOCKET_TIMEOUT=600\n    export GLOO_TCP_TIMEOUT=600\n    export GLOO_LOG_LEVEL=DEBUG\n    export RAY_MASTER_PORT=${RAY_MASTER_PORT:-6379}\n    export RAY_DASHBOARD_PORT=${RAY_DASHBOARD_PORT:-8265}\n    export RAY_MASTER_ADDR=$MASTER_ADDR\n\n    start_ray_cluster\n\n    if [ \"$NNODES\" -gt 1 ] && [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"Waiting for all $NNODES nodes to join...\"\n        local TIMEOUT=600; local start_time=$(date +%s)\n        while true; do\n            if [ \"$(( $(date +%s) - start_time ))\" -ge \"$TIMEOUT\" ]; then echo \"Error: Timeout waiting for nodes.\" >&2; exit 1; fi\n            local ready_nodes=$(ray list nodes --format=json | python3 -c \"import sys, json; print(len(json.load(sys.stdin)))\")\n            if [ \"$ready_nodes\" -ge \"$NNODES\" ]; then break; fi\n            echo \"Waiting... ($ready_nodes / $NNODES nodes ready)\"\n            sleep 2\n        done\n        echo \"All $NNODES nodes have joined.\"\n    fi\n\n    if [ \"$NODE_RANK\" = \"0\" ]; then\n        echo \"INFO [RANK 0]: Starting main training command.\"\n        eval \"${TRAINING_CMD[@]}\" \"$@\"\n        echo \"INFO [RANK 0]: Training finished.\"\n        sleep 30; ray stop --force >/dev/null 2>&1\n    elif [ \"$NNODES\" -gt 1 ]; then\n        local head_node_address=\"$MASTER_ADDR:$RAY_MASTER_PORT\"\n        echo \"INFO [RANK $NODE_RANK]: Worker active. Monitoring head node at $head_node_address.\"\n        while ray health-check --address \"$head_node_address\" &>/dev/null; do sleep 15; done\n        echo \"INFO [RANK $NODE_RANK]: Head node down. Exiting.\"\n    fi\n\n    echo \"INFO: Script finished on rank $NODE_RANK.\"\n}\n\n# --- Script Entrypoint ---\nmain \"$@\"\n"
  },
  {
    "path": "pyproject.toml",
    "content": "# ===================================================================\n# pyproject.toml for siirl\n#\n# PEP 621-compliant configuration file for project metadata,\n# build system, and tool configurations. This file works in\n# conjunction with a minimal setup.py shim.\n# ===================================================================\n\n\n# -------------------------------\n# Build System\n# -------------------------------\n[build-system]\nrequires = [\n    \"setuptools>=61.0\",\n    \"setuptools_scm[toml]>=6.2\",\n    \"wheel\"\n]\nbuild-backend = \"setuptools.build_meta\"\n\n\n# -------------------------------\n# Project Metadata (PEP 621)\n# -------------------------------\n[project]\nname = \"siirl\"\n# Version is loaded dynamically from a file. See [tool.setuptools.dynamic].\ndynamic = [\"version\"]\n\ndescription = \"siirl: A Decentralized Multi-Agent Reinforcement Learning Framework\"\nlicense = {file = \"LICENSE\"}\nreadme = {file = \"README.md\", content-type = \"text/markdown\"}\nrequires-python = \">=3.8\"\n\n# --- Author & URL Information ---\nauthors = [\n  { name = \"Shanghai Innovation Institute - AI Infra Team\", email = \"llm19900326@gmail.com\" },\n]\n\n# --- Project Discovery ---\nkeywords = [\"reinforcement learning\", \"multi-agent\", \"decentralized\", \"rl\", \"ai\"]\n\n# Standardized classifiers from https://pypi.org/classifiers/\nclassifiers = [\n    \"Development Status :: 4 - Beta\",\n    \"Intended Audience :: Developers\",\n    \"Intended Audience :: Science/Research\",\n    \"License :: OSI Approved :: Apache Software License\",\n    \"Programming Language :: Python :: 3\",\n    \"Programming Language :: Python :: 3.10\",\n    \"Programming Language :: Python :: 3.11\",\n    \"Programming Language :: Python :: 3.12\",\n    \"Operating System :: OS Independent\",\n    \"Topic :: Scientific/Engineering :: Artificial Intelligence\",\n]\n\n\n# --- Dependencies ---\n# Runtime dependencies required by the project.\ndependencies = [\n    \"accelerate\",\n    \"codetiming\",\n    \"datasets>=4.0.0\",\n    \"dill\",\n    \"hydra-core\",\n    \"numpy\",\n    \"pandas\",\n    \"peft\",\n    \"pyarrow>=19.0.0\",\n    \"pybind11\",\n    \"pylatexenc\",\n    \"ray[default]>=2.47.1\",\n    \"torchdata\",\n    \"tensordict>=0.8.0,<=0.9.1,!=0.9.0\",\n    \"wandb\",\n    \"tensorboard\",\n    \"mathruler\",\n    \"math_verify\",\n    \"timm\",\n    \"imageio\",\n    \"loguru\",\n    \"packaging>=20.0\",\n    \"dacite\",\n    \"qwen_vl_utils\",\n    \"scipy\",\n    \"fastapi\",\n    \"transformers\",\n    \"math-verify\",\n    \"vllm>=0.8.5.post1\",\n]\n\n\n# --- Optional Dependencies ---\n# Corresponds to 'extras_require' in setup.py.\n# Install with: pip install \"siirl[gpu]\"\n[project.optional-dependencies]\n# For core development and releasing\ndev = [\n    \"ruff\",\n    \"pytest\",\n    \"build\",\n    \"twine\",\n    \"pre-commit\",\n    \"py-spy\",\n]\ntest = [\n    \"pytest\",\n    \"pre-commit\",\n    \"py-spy\",\n    \"pyext\",\n]\n\ngeo = [\"mathruler\"]\ngpu = [\"liger-kernel\", \"flash-attn\"]\nsglang = [\n    \"tensordict>=0.8.0,<=0.9.1,!=0.9.0\",\n    \"sglang[all]>=0.4.6.post5\",\n    \"torch-memory-saver>=0.0.5\",\n    \"torch>=2.6.0\",\n]\n\n\n# --- Project URLs ---\n# This table should only contain string key-value pairs for URLs.\n[project.urls]\n\"Homepage\" = \"https://github.com/sii-research/siiRL\"\n\"Bug Tracker\" = \"https://github.com/sii-research/siiRL/issues\"\n\"Repository\" = \"https://github.com/sii-research/siiRL\"\n\n\n# 
-------------------------------\n# Tool: Ruff (Linting)\n# -------------------------------\n[tool.ruff]\nline-length = 120 # TODO: Reduce this to a more reasonable value\n\n[tool.ruff.lint]\nisort = {known-first-party = [\"siirl\"]}\nselect = [ \"E\", \"F\", \"UP\", \"B\", \"I\", \"G\" ]\nignore = [ \"F405\", \"F403\", \"E731\", \"B007\", \"UP032\", \"UP007\", \"G004\" ]\n\n\n# -------------------------------\n# Tool: Setuptools\n# -------------------------------\n[tool.setuptools]\ninclude-package-data = true\n# Modern equivalent of find_packages()\npackages = { find = {} }\n\n[tool.setuptools_scm]\nwrite_to = \"siirl/_version.py\"\n\n[tool.setuptools.package-dir]\n\"\" = \".\"\n\n[tool.setuptools.package-data]\nsiirl = [\n  \"client/config/*.yaml\"\n]\n"
  },
  {
    "path": "requirements-npu.txt",
    "content": "accelerate\ncodetiming\ndatasets>=4.0.0\ndill\nhydra-core\nnumpy\npandas\npeft\npyarrow>=19.0.0\npybind11\npylatexenc\nray[default]>=2.47.1\ntorchdata\ntensordict>=0.8.0,<=0.9.1,!=0.9.0,\ntransformers\nwandb\ntensorboard\nmathruler\nmath_verify\ntimm\nimageio\nloguru\npackaging>=20.0\ndacite\nqwen_vl_utils\nscipy\nfastapi\ntorch_npu==2.5.1\nvllm>=0.9.1\nvllm_ascend>=0.9.1\nmbridge==0.13.0\n"
  },
  {
    "path": "requirements.txt",
    "content": "accelerate\ncodetiming\ndatasets>=4.0.0\ndill\nhydra-core\nnumpy\npandas\npeft\npyarrow>=19.0.0\npybind11\npylatexenc\nray[default]>=2.47.1\ntorchdata\ntensordict>=0.8.0,<=0.9.1,!=0.9.0,\ntransformers\nwandb\ntensorboard\nmathruler\nmath_verify\ntimm\nimageio\nloguru\npackaging>=20.0\ndacite\nqwen_vl_utils\nscipy\nfastapi\nvllm>=0.8.5.post1\n"
  },
  {
    "path": "setup.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom setuptools import setup\n\n# This is a \"shim\" setup.py file that delegates all configuration\n# to the pyproject.toml file. This is the recommended approach for\n# projects that need to maintain a setup.py for compatibility while\n# adopting modern packaging standards.\n#\n# All metadata, dependencies, and package data are defined in pyproject.toml.\n# This setup() call is intentionally left empty.\nsetup()"
  },
  {
    "path": "siirl/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\n\nfrom importlib.metadata import version\nfrom packaging.version import parse as parse_version\nfrom importlib.metadata import PackageNotFoundError\n\nfrom siirl.utils.extras.device import is_npu_available\nfrom siirl.utils.logger.logging_utils import set_basic_config\n\n\nset_basic_config()\n\n\n__all__ = []\n\nif os.getenv(\"SIIRL_USE_MODELSCOPE\", \"False\").lower() == \"true\":\n    import importlib\n\n    if importlib.util.find_spec(\"modelscope\") is None:\n        raise ImportError(\"You are using the modelscope hub, please install modelscope by `pip install modelscope -U`\")\n    # Patch hub to download models from modelscope to speed up.\n    from modelscope.utils.hf_util import patch_hub\n\n    patch_hub()\n\nif is_npu_available:\n    from .models.transformers import npu_patch as npu_patch\n\n    package_name = \"transformers\"\n    required_version_spec = \"4.52.4\"\n    try:\n        installed_version = version(package_name)\n        installed = parse_version(installed_version)\n        required = parse_version(required_version_spec)\n\n        if not installed >= required:\n            raise ValueError(f\"{package_name} version >= {required_version_spec} is required on ASCEND NPU, current version is {installed}.\")\n    except PackageNotFoundError:\n        raise ImportError(f\"package {package_name} is not installed, please run pip install {package_name}=={required_version_spec}\")\n"
  },
  {
    "path": "siirl/dag_worker/__init__.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "siirl/dag_worker/checkpoint_manager.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Checkpoint save/load operations for distributed training.\"\"\"\n\nimport os\nimport torch\nimport torch.distributed as dist\nfrom typing import Dict, Optional, Any\nfrom loguru import logger\n\nfrom siirl.execution.dag.node import NodeRole, NodeType\nfrom siirl.params import SiiRLArguments\nfrom siirl.dag_worker.constants import DAGConstants\nfrom siirl.dag_worker.dag_utils import generate_node_worker_key\nfrom siirl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path\n\n\nclass CheckpointManager:\n    \"\"\"Manages distributed checkpoint save/load with atomic commits.\"\"\"\n\n    def __init__(\n        self,\n        config: SiiRLArguments,\n        rank: int,\n        gather_group: dist.ProcessGroup,\n        workers: Dict[str, Any],\n        taskgraph: Any,\n        dataloader: Any,\n        first_rollout_node: Any,\n        get_node_dp_info_fn: callable\n    ):\n        self.config = config\n        self.rank = rank\n        self.gather_group = gather_group\n        self.workers = workers\n        self.taskgraph = taskgraph\n        self.dataloader = dataloader\n        self.first_rollout_node = first_rollout_node\n        self._get_node_dp_info = get_node_dp_info_fn\n\n    def save_checkpoint(self, global_steps: int) -> None:\n        \"\"\"Save checkpoint atomically across all ranks.\"\"\"\n        step_dir = os.path.join(self.config.trainer.default_local_dir, f\"global_step_{global_steps}\")\n        os.makedirs(step_dir, exist_ok=True)\n        dist.barrier(self.gather_group)\n\n        logger.info(f\"Rank {self.rank}: Saving checkpoint for global_step {global_steps} to {step_dir}\")\n\n        self._save_model_states(global_steps, step_dir)\n        self._save_dataloader_state(step_dir)\n\n        logger.debug(f\"Rank {self.rank}: All data saved. 
Waiting at barrier before committing checkpoint.\")\n        dist.barrier(self.gather_group)\n\n        if self.rank == 0:\n            self._commit_checkpoint(global_steps)\n\n        dist.barrier(self.gather_group)\n        logger.info(f\"Rank {self.rank}: Finished saving and committing checkpoint for step {global_steps}.\")\n\n    def _save_model_states(self, global_steps: int, step_dir: str) -> None:\n        \"\"\"Save model states for all trainable nodes.\"\"\"\n        saved_worker_keys = set()\n\n        for node in self.taskgraph.nodes.values():\n            if node.node_type != NodeType.MODEL_TRAIN:\n                continue\n            if node.node_role not in [NodeRole.ACTOR, NodeRole.CRITIC]:\n                continue\n\n            node_worker_key = generate_node_worker_key(node)\n\n            if node_worker_key in saved_worker_keys:\n                continue\n\n            worker = self.workers[node_worker_key]\n\n            sub_dir_name = f\"{node.node_role.name.lower()}_agent_{node.agent_group}\"\n            checkpoint_path = os.path.join(step_dir, sub_dir_name)\n\n            role_name_for_config = node.node_role.name.lower()\n            max_ckpt_keep = getattr(\n                self.config.trainer,\n                f\"max_{role_name_for_config}_ckpt_to_keep\",\n                10\n            )\n\n            worker.save_checkpoint(\n                local_path=checkpoint_path,\n                global_step=global_steps,\n                max_ckpt_to_keep=max_ckpt_keep\n            )\n            saved_worker_keys.add(node_worker_key)\n            logger.debug(\n                f\"Rank {self.rank}: Saved {node.node_role.name} checkpoint for agent {node.agent_group} \"\n                f\"to {checkpoint_path}\"\n            )\n\n    def _save_dataloader_state(self, step_dir: str) -> None:\n        \"\"\"Save dataloader state (only TP rank 0 and PP rank 0 per DP group).\"\"\"\n        _, dp_rank, tp_rank, _, pp_rank, _ = self._get_node_dp_info(self.first_rollout_node)\n\n        if tp_rank == 0 and pp_rank == 0:\n            dataloader_path = os.path.join(step_dir, f\"data_dp_rank_{dp_rank}.pt\")\n            dataloader_state = self.dataloader.state_dict()\n            torch.save(dataloader_state, dataloader_path)\n            logger.debug(\n                f\"Rank {self.rank} (DP_Rank {dp_rank}, TP_Rank {tp_rank}, PP_Rank {pp_rank}): \"\n                f\"Saved dataloader state to {dataloader_path}\"\n            )\n\n    def _commit_checkpoint(self, global_steps: int) -> None:\n        \"\"\"Atomically commit checkpoint by writing tracker file (rank 0 only).\"\"\"\n        tracker_file = os.path.join(\n            self.config.trainer.default_local_dir,\n            \"latest_checkpointed_iteration.txt\"\n        )\n        with open(tracker_file, \"w\") as f:\n            f.write(str(global_steps))\n        logger.info(f\"Rank 0: Checkpoint for step {global_steps} successfully committed.\")\n\n    def load_checkpoint(self) -> int:\n        \"\"\"Load checkpoint and return global step to resume from.\"\"\"\n        if self.config.trainer.resume_mode == \"disable\":\n            if self.rank == 0:\n                logger.info(\"Checkpoint loading is disabled. 
Starting from scratch.\")\n            return 0\n\n        checkpoint_path = self._determine_checkpoint_path()\n\n        checkpoint_path_container = [checkpoint_path]\n        dist.broadcast_object_list(checkpoint_path_container, src=0)\n        global_step_folder = checkpoint_path_container[0]\n\n        if global_step_folder is None:\n            if self.rank == 0:\n                logger.info(\"No valid checkpoint to load. Training will start from step 0.\")\n            dist.barrier(self.gather_group)\n            return 0\n\n        try:\n            global_steps = int(os.path.basename(global_step_folder).split(\"global_step_\")[-1])\n            logger.info(\n                f\"Rank {self.rank}: Resuming from checkpoint. \"\n                f\"Setting global_steps to {global_steps}.\"\n            )\n        except (ValueError, IndexError) as e:\n            raise ValueError(\n                f\"Could not parse global step from checkpoint path: {global_step_folder}\"\n            ) from e\n\n        self._load_model_states(global_step_folder)\n        self._load_dataloader_state(global_step_folder)\n\n        dist.barrier(self.gather_group)\n        logger.info(f\"Rank {self.rank}: Finished loading all checkpoint components.\")\n\n        return global_steps\n\n    def _determine_checkpoint_path(self) -> Optional[str]:\n        \"\"\"Determine checkpoint path (rank 0 only).\"\"\"\n        if self.rank != 0:\n            return None\n\n        checkpoint_dir = self.config.trainer.default_local_dir\n        resume_from_path = self.config.trainer.resume_from_path\n        path_to_load = None\n\n        if self.config.trainer.resume_mode == \"auto\":\n            latest_path = find_latest_ckpt_path(checkpoint_dir)\n            if latest_path:\n                logger.info(f\"Rank 0: Auto-found latest checkpoint at {latest_path}\")\n                path_to_load = latest_path\n        elif self.config.trainer.resume_mode == \"resume_path\" and resume_from_path:\n            logger.info(f\"Rank 0: Attempting to load from specified path: {resume_from_path}\")\n            path_to_load = resume_from_path\n\n        if path_to_load and os.path.exists(path_to_load):\n            return path_to_load\n        else:\n            logger.warning(\n                f\"Rank 0: Checkpoint path not found or invalid: '{path_to_load}'. 
\"\n                f\"Starting from scratch.\"\n            )\n            return None\n\n    def _load_model_states(self, global_step_folder: str) -> None:\n        \"\"\"Load model states for all trainable nodes.\"\"\"\n        loaded_worker_keys = set()\n\n        for node in self.taskgraph.nodes.values():\n            if node.node_type != NodeType.MODEL_TRAIN:\n                continue\n            if node.node_role not in [NodeRole.ACTOR, NodeRole.CRITIC]:\n                continue\n\n            node_worker_key = generate_node_worker_key(node)\n\n            if node_worker_key in loaded_worker_keys:\n                continue\n\n            worker = self.workers[node_worker_key]\n\n            sub_dir_name = f\"{node.node_role.name.lower()}_agent_{node.agent_group}\"\n            checkpoint_path = os.path.join(global_step_folder, sub_dir_name)\n\n            if os.path.exists(checkpoint_path):\n                worker.load_checkpoint(\n                    local_path=checkpoint_path,\n                    del_local_after_load=self.config.trainer.del_local_ckpt_after_load\n                )\n                loaded_worker_keys.add(node_worker_key)\n                logger.debug(\n                    f\"Rank {self.rank}: Loaded {node.node_role.name} checkpoint for agent \"\n                    f\"{node.agent_group} from {checkpoint_path}\"\n                )\n            else:\n                logger.warning(\n                    f\"Rank {self.rank}: Checkpoint for agent {node.agent_group}'s \"\n                    f\"{node.node_role.name} not found at {checkpoint_path}. \"\n                    f\"Weights will be from initialization. \"\n                    f\"If has multi-agent, will share the same checkpoint in agents\"\n                )\n\n    def _load_dataloader_state(self, global_step_folder: str) -> None:\n        \"\"\"Load dataloader state for current DP group.\"\"\"\n        _, dp_rank, _, _, _, _ = self._get_node_dp_info(self.first_rollout_node)\n        dataloader_path = os.path.join(global_step_folder, f\"data_dp_rank_{dp_rank}.pt\")\n\n        if os.path.exists(dataloader_path):\n            dataloader_state = torch.load(dataloader_path, map_location=\"cpu\")\n            self.dataloader.load_state_dict(dataloader_state)\n            logger.debug(\n                f\"Rank {self.rank} (DP_Rank {dp_rank}): Loaded dataloader state from \"\n                f\"{dataloader_path}\"\n            )\n        else:\n            logger.warning(\n                f\"Rank {self.rank} (DP_Rank {dp_rank}): Dataloader checkpoint not found at \"\n                f\"{dataloader_path}. Sampler state will not be restored, which may lead to \"\n                f\"data inconsistency.\"\n            )\n"
  },
  {
    "path": "siirl/dag_worker/constants.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Dict, List\nfrom siirl.execution.dag.node import NodeRole\n\n\nclass DAGInitializationError(Exception):\n    \"\"\"Custom exception for failures during DAGWorker initialization.\"\"\"\n\n    pass\n\n\nclass DAGConstants:\n    \"\"\"Centralized constants to improve maintainability and avoid magic strings.\"\"\"\n\n    # Worker role mapping\n    WORKER_ROLE_MAPPING: Dict[NodeRole, str] = {\n        NodeRole.ACTOR: \"actor\",\n        NodeRole.ROLLOUT: \"rollout\",\n        NodeRole.REFERENCE: \"ref\",\n    }\n    # Configuration keys\n    INTERN_CONFIG: str = \"intern_config\"\n    # Framework strategy names\n    FSDP_STRATEGIES: List[str] = [\"fsdp\", \"fsdp2\"]\n    MEGATRON_STRATEGYS: List[str] = [\"megatron\"]\n    # keep this for backward compatibility\n    MEGATRON_STRATEGY: str = \"megatron\"\n    # Metric group order\n    METRIC_GROUP_ORDER = [\"step\", \"training\", \"actor\", \"critic\", \"perf\", \"response_length\", \"response\", \"prompt_length\", \"prompt\", \"dapo_sampling\", \"global_seqlen\", \"timing_s\", \"timing_per_token_ms\", \"perf/total_num_tokens\", \"perf/time_per_step\", \"perf/throughput\"]\n"
  },
  {
    "path": "siirl/dag_worker/core_algos.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nCore functions to implement PPO algorithms.\nThe function implemented in this file should be used by trainer with different distributed strategies to\nimplement PPO-like algorithms.\n\"\"\"\n\n__all__ = [\"register_adv_est\", \"get_adv_estimator_fn\", \"AdvantageEstimator\"]\n\nimport math\nfrom collections import defaultdict\nfrom enum import Enum\nfrom typing import Any, Callable, Optional\n\nimport numpy as np\nimport torch\nfrom omegaconf import DictConfig\nfrom loguru import logger\nimport siirl.utils.model_utils.torch_functional as siirl_F\nfrom siirl.params.model_args import AlgorithmArguments, ActorArguments\nfrom siirl.execution.scheduler.enums import AdvantageEstimator\nfrom tensordict import TensorDict \n\n\nPolicyLossFn = Callable[\n    [\n        torch.Tensor,  # old_log_prob\n        torch.Tensor,  # log_prob\n        torch.Tensor,  # advantages\n        torch.Tensor,  # response_mask\n        str,  # loss_agg_mode\n        Optional[DictConfig | AlgorithmArguments],  # config\n        torch.Tensor | None,  # rollout_log_probs\n    ],\n    tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],\n]\n\nPOLICY_LOSS_REGISTRY: dict[str, PolicyLossFn] = {}\n\n\ndef register_policy_loss(name: str) -> Callable[[PolicyLossFn], PolicyLossFn]:\n    \"\"\"Register a policy loss function with the given name.\n\n    Args:\n        name (str): The name to register the policy loss function under.\n\n    Returns:\n        function: Decorator function that registers the policy loss function.\n    \"\"\"\n\n    def decorator(func: PolicyLossFn) -> PolicyLossFn:\n        POLICY_LOSS_REGISTRY[name] = func\n        return func\n\n    return decorator\n\n\ndef get_policy_loss_fn(name):\n    \"\"\"Get the policy loss with a given name.\n\n    Args:\n        name: `(str)`\n            The name of the policy loss.\n\n    Returns:\n        `(callable)`: The policy loss function.\n    \"\"\"\n    loss_name = name\n    if loss_name not in POLICY_LOSS_REGISTRY:\n        raise ValueError(\n            f\"Unsupported loss mode: {loss_name}. 
Supported modes are: {list(POLICY_LOSS_REGISTRY.keys())}\"\n        )\n    return POLICY_LOSS_REGISTRY[loss_name]\n\n\ndef compute_response_mask(data: TensorDict):\n    \"\"\"Compute the attention mask for the response part of the sequence.\n    \n    Handles both 2D responses (NLP) and 3D responses (Embodied AI).\n    \n    Returns:\n        torch.Tensor: The attention mask for the response tokens (always 2D).\n    \"\"\"\n    responses = data[\"responses\"]\n    attention_mask = data[\"attention_mask\"]\n    batch_size = responses.size(0)\n    \n    # Handle 3D responses (Embodied AI): (batch_size, traj_len, action_token_len)\n    if responses.ndim == 3:\n        traj_len = responses.size(1)\n        action_token_len = responses.size(2)\n        \n        # Check if attention_mask is also 3D\n        if attention_mask.ndim == 3:\n            # attention_mask: (batch_size, traj_len, tot_pad_len)\n            # Extract response part from last dimension: (batch_size, traj_len, action_token_len)\n            response_mask = attention_mask[:, :, -action_token_len:]\n            # Flatten to 2D: (batch_size, traj_len * action_token_len)\n            response_mask = response_mask.reshape(batch_size, -1)\n        else:\n            # attention_mask is 2D: (batch_size, total_length)\n            # Calculate flattened response_length and slice\n            response_length = traj_len * action_token_len\n            response_mask = attention_mask[:, -response_length:]\n    # Handle 2D responses (NLP): (batch_size, response_length)\n    elif responses.ndim == 2:\n        response_length = responses.size(1)\n        response_mask = attention_mask[:, -response_length:]\n    else:\n        raise ValueError(f\"Unexpected responses shape: {responses.shape}, ndim={responses.ndim}\")\n    \n    return response_mask\n\nADV_ESTIMATOR_REGISTRY: dict[str, Any] = {}\n\n\ndef register_adv_est(name_or_enum: str | AdvantageEstimator) -> Any:\n    \"\"\"Decorator to register a advantage estimator function with a given name.\n\n    Args:\n        name_or_enum: `(str)` or `(AdvantageEstimator)`\n            The name or enum of the advantage estimator.\n\n    \"\"\"\n\n    def decorator(fn):\n        name = name_or_enum.value if isinstance(name_or_enum, Enum) else name_or_enum\n        if name in ADV_ESTIMATOR_REGISTRY and ADV_ESTIMATOR_REGISTRY[name] != fn:\n            raise ValueError(\n                f\"Adv estimator {name} has already been registered: {ADV_ESTIMATOR_REGISTRY[name]} vs {fn}\"\n            )\n        ADV_ESTIMATOR_REGISTRY[name] = fn\n        return fn\n\n    return decorator\n\n\ndef get_adv_estimator_fn(name_or_enum):\n    \"\"\"Get the advantage estimator function with a given name.\n\n    Args:\n        name_or_enum: `(str)` or `(AdvantageEstimator)`\n            The name or enum of the advantage estimator.\n\n    Returns:\n        `(callable)`: The advantage estimator function.\n    \"\"\"\n    name = name_or_enum.value if isinstance(name_or_enum, Enum) else name_or_enum\n    if name not in ADV_ESTIMATOR_REGISTRY:\n        raise ValueError(f\"Unknown advantage estimator simply: {name}\")\n    return ADV_ESTIMATOR_REGISTRY[name]\n\n\nclass AdaptiveKLController:\n    \"\"\"\n    Adaptive KL controller described in the paper:\n    https://arxiv.org/pdf/1909.08593.pdf\n    \"\"\"\n\n    def __init__(self, init_kl_coef, target_kl, horizon):\n        self.value = init_kl_coef\n        self.target = target_kl\n        self.horizon = horizon\n\n    def update(self, current_kl, n_steps):\n        
\"\"\"Update the KL coefficient based on current KL divergence.\n\n        Args:\n            current_kl (float): Current KL divergence value.\n            n_steps (int): Number of steps taken.\n        \"\"\"\n        target = self.target\n        proportional_error = np.clip(current_kl / target - 1, -0.2, 0.2)\n        mult = 1 + proportional_error * n_steps / self.horizon\n        self.value *= mult\n\n\nclass FixedKLController:\n    \"\"\"Fixed KL controller.\"\"\"\n\n    def __init__(self, kl_coef):\n        self.value = kl_coef\n\n    def update(self, current_kl, n_steps):\n        \"\"\"Update method for fixed KL controller (no-op).\n\n        Args:\n            current_kl (float): Current KL divergence value (unused).\n            n_steps (int): Number of steps taken (unused).\n        \"\"\"\n        pass\n\n\ndef get_kl_controller(kl_ctrl):\n    \"\"\"Factory function to create appropriate KL controller based on configuration.\n\n    Args:\n        kl_ctrl: Configuration object containing KL controller settings.\n\n    Returns:\n        KL controller instance (FixedKLController or AdaptiveKLController).\n\n    Raises:\n        NotImplementedError: If controller type is not supported.\n        AssertionError: If adaptive controller horizon is not positive.\n    \"\"\"\n    if kl_ctrl.type == \"fixed\":\n        return FixedKLController(kl_coef=kl_ctrl.kl_coef)\n    elif kl_ctrl.type == \"adaptive\":\n        assert kl_ctrl.horizon > 0, f\"horizon must be larger than 0. Got {kl_ctrl.horizon}\"\n        return AdaptiveKLController(init_kl_coef=kl_ctrl.kl_coef, target_kl=kl_ctrl.target_kl, horizon=kl_ctrl.horizon)\n    else:\n        raise NotImplementedError\n\n\n@register_adv_est(AdvantageEstimator.GAE)  # or simply: @register_adv_est(\"gae\")\ndef compute_gae_advantage_return(\n    token_level_rewards: torch.Tensor,\n    values: torch.Tensor,\n    response_mask: torch.Tensor,\n    gamma: torch.Tensor,\n    lam: torch.Tensor,\n):\n    \"\"\"Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py\n\n    Args:\n        token_level_rewards: `(torch.Tensor)`\n            shape is (bs, response_length)\n        values: `(torch.Tensor)`\n            shape is (bs, response_length)\n        response_mask: `(torch.Tensor)`\n            shape is (bs, response_length). [EOS] mask. 
The token after [EOS] have mask zero.\n        gamma is `(float)`\n            discounted factor used in RL\n        lam: `(float)`\n            lambda value when computing Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)\n\n    Returns:\n        advantages: `(torch.Tensor)`\n            shape: (bs, response_length)\n        Returns: `(torch.Tensor)`\n            shape: (bs, response_length)\n\n    \"\"\"\n    with torch.no_grad():\n        nextvalues = 0\n        lastgaelam = 0\n        advantages_reversed = []\n        gen_len = token_level_rewards.shape[-1]\n        for t in reversed(range(gen_len)):\n            delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t]\n            lastgaelam_ = delta + gamma * lam * lastgaelam\n\n            # skip values and TD-error on observation tokens\n            nextvalues = values[:, t] * response_mask[:, t] + (1 - response_mask[:, t]) * nextvalues\n            lastgaelam = lastgaelam_ * response_mask[:, t] + (1 - response_mask[:, t]) * lastgaelam\n\n            advantages_reversed.append(lastgaelam)\n        advantages = torch.stack(advantages_reversed[::-1], dim=1)\n\n        returns = advantages + values\n        advantages = siirl_F.masked_whiten(advantages, response_mask)\n    return advantages, returns\n\n\n@register_adv_est(AdvantageEstimator.GRPO)  # or simply: @register_adv_est(\"grpo\")\ndef compute_grpo_outcome_advantage(\n    token_level_rewards: torch.Tensor,\n    response_mask: torch.Tensor,\n    index: np.ndarray,\n    epsilon: float = 1e-6,\n    norm_adv_by_std_in_grpo: bool = True,\n    config: Optional[AlgorithmArguments] = None,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Compute advantage for GRPO, operating only on Outcome reward\n    (with only one scalar reward for each response).\n\n    Args:\n        token_level_rewards: `(torch.Tensor)`\n            shape is (bs, response_length)\n        response_mask: `(torch.Tensor)`\n            shape is (bs, response_length)\n        index: `(np.ndarray)`\n            index array for grouping\n        epsilon: `(float)`\n            small value to avoid division by zero\n        norm_adv_by_std_in_grpo: `(bool)`\n            whether to scale the GRPO advantage\n        config: `(Optional[AlgorithmArguments])`\n            algorithm configuration object\n\n    Note:\n        If norm_adv_by_std_in_grpo is True, the advantage is scaled by the std, as in the original GRPO.\n        If False, the advantage is not scaled, as in Dr.GRPO (https://arxiv.org/abs/2503.20783).\n\n    Returns:\n        advantages: `(torch.Tensor)`\n            shape is (bs, response_length)\n        Returns: `(torch.Tensor)`\n            shape is (bs, response_length)\n    \"\"\"\n    scores = token_level_rewards.sum(dim=-1)\n\n    id2score = defaultdict(list)\n    id2mean = {}\n    id2std = {}\n\n    with torch.no_grad():\n        bsz = scores.shape[0]\n        for i in range(bsz):\n            if isinstance(index[i], torch.Tensor):\n                idx_key = index[i].item()\n            else:\n                idx_key = index[i]\n            id2score[idx_key].append(scores[i])\n\n        for idx in id2score:\n            if len(id2score[idx]) == 1:\n                id2mean[idx] = torch.tensor(0.0)\n                id2std[idx] = torch.tensor(1.0)\n            elif len(id2score[idx]) > 1:\n                scores_tensor = torch.stack(id2score[idx])\n                id2mean[idx] = torch.mean(scores_tensor)\n                id2std[idx] = torch.std(scores_tensor)\n      
      else:\n                raise ValueError(f\"no score in prompt index: {idx}\")\n\n        for i in range(bsz):\n            if isinstance(index[i], torch.Tensor):\n                idx_key = index[i].item()\n            else:\n                idx_key = index[i]\n            if norm_adv_by_std_in_grpo:\n                scores[i] = (scores[i] - id2mean[idx_key]) / (id2std[idx_key] + epsilon)\n            else:\n                scores[i] = scores[i] - id2mean[idx_key]\n        scores = scores.unsqueeze(-1) * response_mask\n\n    return scores, scores\n\n\ndef compute_marft_gae_advantage_return(\n    data: TensorDict,\n    pre_agent_group_ids,\n    gamma: torch.Tensor,\n    lam: torch.Tensor,\n):\n    \"\"\"\n    Args:\n        data: `TensorDict`\n        pre_agent_group_ids: `List`\n            pre agent id\n        gamma: `(float)`\n            discounted factor used in RL\n        lam: `(float)`\n            lambda value when computing Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)\n\n    Returns:\n        advantages: `(torch.Tensor)`\n            shape: (bs, response_length)\n        Returns: `(torch.Tensor)`\n            shape: (bs, response_length)\n\n    \"\"\"\n\n    key_prefix = \"agent_group_\"\n\n    async def compute_traj_adv():\n        pass\n\n    token_level_rewards = []\n    values = []\n    response_mask = []\n    advantages = []\n    returns = []\n    for agent_group in pre_agent_group_ids:\n        key = key_prefix + str(agent_group)\n        token_level_rewards.append(data.batch[key + \"_token_level_rewards\"])\n        values.append(data.batch[key + \"_values\"])\n        response_mask.append(data.batch[key + \"_response_mask\"])\n        advantages.append(torch.zeros_like(response_mask[-1]))\n        returns.append(torch.zeros_like(advantages[-1]))\n    token_level_rewards.append(data.batch[\"token_level_rewards\"])\n    values.append(data.batch[\"values\"])\n    response_mask.append(data.batch[\"response_mask\"])\n    advantages.append(torch.zeros_like(response_mask[-1]))\n    returns.append(torch.zeros_like(advantages[-1]))\n    pre_agent_group_ids.append(pre_agent_group_ids[-1] + 1)\n\n    with torch.no_grad():\n        seen = set()\n        dp_start_bs = [\n            i for i, s in enumerate(data.non_tensor_batch[\"request_id\"]) if s not in seen and not seen.add(s)\n        ]\n        # loop all batch_size\n        for bs_id in dp_start_bs:\n            # last agent last traj last token\n            gae = 0\n            traj_len = data.non_tensor_batch[\"traj_len\"][bs_id]\n            # loop each traj， tra has been reserved\n            for traj_idx in range(traj_len):\n                # loop each agent of traj\n                traj_bs_id = bs_id + traj_idx\n                # assert traj_idx == traj_len - data.non_tensor_batch[\"traj_step\"][traj_bs_id] - 1,\n                # f'traj_idx {traj_idx}, traj_bs_id: {traj_bs_id}, traj_step\n                # {data.non_tensor_batch[\"traj_step\"][traj_bs_id]}, traj_len {traj_len},\n                # request {data.non_tensor_batch[\"request_id\"][traj_bs_id]},request_data {data} '\n                for agent_idx in reversed(pre_agent_group_ids):\n                    gen_len = response_mask[agent_idx][traj_bs_id].sum()\n                    # loop each token of agent\n                    for t in reversed(range(gen_len)):\n                        rew = token_level_rewards[agent_idx][traj_bs_id, t]\n                        v = values[agent_idx][traj_bs_id, t]\n                        if agent_idx == 
pre_agent_group_ids[-1]:\n                            # last_agent\n                            if t == gen_len - 1:\n                                # last_token\n                                if traj_idx == 0:\n                                    v_next = 0\n                                else:\n                                    v_next = values[0][traj_bs_id - 1, 0]\n                                delta = rew + gamma * v_next - v\n                                gae = delta + gamma * lam * gae\n                            else:\n                                v_next = values[agent_idx][traj_bs_id, t + 1]\n                                delta = gamma * v_next - v\n                                gae = delta + gamma * lam * gae\n                        else:\n                            # not last agent\n                            if t == gen_len - 1:\n                                # last_token\n                                v_next = values[agent_idx + 1][traj_bs_id, 0]\n                                delta = rew + gamma * v_next - v\n                                gae = delta + gamma * lam * gae\n                            else:\n                                v_next = values[agent_idx][traj_bs_id, t + 1]\n                                delta = gamma * v_next - v\n                                gae = delta + gamma * lam * gae\n                        advantages[agent_idx][traj_bs_id, t] = gae\n                        returns[agent_idx][traj_bs_id, t] = gae + v\n        for agent_idx in pre_agent_group_ids:\n            advantages[agent_idx] = siirl_F.masked_whiten(advantages[agent_idx], response_mask[agent_idx])\n            if agent_idx != pre_agent_group_ids[-1]:\n                data.batch[key_prefix + str(agent_idx) + \"_advantages\"] = advantages[agent_idx]\n                data.batch[key_prefix + str(agent_idx) + \"_returns\"] = returns[agent_idx]\n            else:\n                data.batch[\"advantages\"] = advantages[agent_idx]\n                data.batch[\"returns\"] = returns[agent_idx]\n    return\n\n\ndef compute_cpgd_outcome_advantage(\n    token_level_rewards: torch.Tensor,\n    response_mask: torch.Tensor,\n    index: np.ndarray,\n    epsilon: float = 1e-6,\n    weight_factor_in_cpgd: str = \"STD_weight\",\n):\n    \"\"\"\n    Compute advantage for CPGD, operating only on Outcome reward\n    (with only one scalar reward for each response).\n    Args:\n        token_level_rewards: `(torch.Tensor)`\n            shape: (bs, response_length)\n        response_mask: `(torch.Tensor)`\n            shape: (bs, response_length)\n        weight_factor_in_cpgd: (str)\n            whether to use the STD weight as GRPO or clip_filter_like_weight.\n            choices: {STD_weight, clip_filter_like_weight, naive}\n\n    Returns:\n        advantages: `(torch.Tensor)`\n            shape: (bs, response_length)\n        Returns: `(torch.Tensor)`\n            shape: (bs, response_length)\n    \"\"\"\n    scores = token_level_rewards.sum(dim=-1)\n\n    id2score = defaultdict(list)\n    id2mean = {}\n    id2std = {}\n\n    with torch.no_grad():\n        bsz = scores.shape[0]\n        for i in range(bsz):\n            id2score[index[i]].append(scores[i])\n        for idx in id2score:\n            if len(id2score[idx]) == 1:\n                id2mean[idx] = torch.tensor(0.0)\n                id2std[idx] = torch.tensor(1.0)\n            elif len(id2score[idx]) > 1:\n                id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))\n                id2std[idx] = 
torch.std(torch.tensor([id2score[idx]]))\n            else:\n                raise ValueError(f\"no score in prompt index: {idx}\")\n        for i in range(bsz):\n            if weight_factor_in_cpgd == \"STD_weight\":\n                scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)\n            elif weight_factor_in_cpgd == \"clip_filter_like_weight\":\n                count_no_0_adv = sum(v != 0 for v in id2std.values())\n                scores[i] = (scores[i] - id2mean[index[i]]) * (bsz / count_no_0_adv).clamp(max=3.0)\n            elif weight_factor_in_cpgd == \"naive\":\n                scores[i] = scores[i] - id2mean[index[i]]\n            else:\n                raise NotImplementedError\n        scores = scores.unsqueeze(-1) * response_mask\n\n    return scores, scores\n\n\ndef compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio):\n    \"\"\"Compute token-level rewards with KL penalty.\n\n    Args:\n        token_level_scores (torch.Tensor): Token-level reward scores.\n        old_log_prob (torch.Tensor): Log probabilities from current policy.\n        ref_log_prob (torch.Tensor): Log probabilities from reference policy.\n        kl_ratio (float): KL penalty coefficient.\n\n    Returns:\n        torch.Tensor: Token-level rewards with KL penalty applied.\n    \"\"\"\n    kl = old_log_prob - ref_log_prob\n    return token_level_scores - kl * kl_ratio\n\n\ndef agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str):\n    \"\"\"\n    Aggregate the loss matrix into a scalar.\n\n    Args:\n        loss_mat: `(torch.Tensor)`:\n            shape: (bs, response_length)\n        loss_mask: `(torch.Tensor)`:\n            shape: (bs, response_length)\n        loss_agg_mode: (str) choices:\n            method to aggregate the loss matrix into a scalar.\n    Returns:\n        loss: `a scalar torch.Tensor`\n            aggregated loss\n    \"\"\"\n    if loss_agg_mode == \"token-mean\":\n        loss = siirl_F.masked_mean(loss_mat, loss_mask)\n    elif loss_agg_mode == \"seq-mean-token-sum\":\n        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)  # token-sum\n        loss = torch.mean(seq_losses)  # seq-mean\n    elif loss_agg_mode == \"seq-mean-token-mean\":\n        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / torch.sum(loss_mask, dim=-1)  # token-mean\n        loss = torch.mean(seq_losses)  # seq-mean\n    elif loss_agg_mode == \"seq-mean-token-sum-norm\":\n        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)\n        loss = torch.sum(seq_losses) / loss_mask.shape[-1]  # The divisor\n        # (loss_mask.shape[-1]) should ideally be constant\n        # throughout training to well-replicate the DrGRPO paper.\n        # TODO: Perhaps add user-defined normalizer argument to\n        # agg_loss to ensure divisor stays constant throughout.\n    else:\n        raise ValueError(f\"Invalid loss_agg_mode: {loss_agg_mode}\")\n\n    return loss\n\n\n@register_policy_loss(\"cpgd\")\ndef compute_policy_loss_cpgd(\n    old_log_prob: torch.Tensor,\n    log_prob: torch.Tensor,\n    advantages: torch.Tensor,\n    response_mask: torch.Tensor,\n    loss_agg_mode: str = \"token-mean\",\n    config: Optional[ActorArguments] = None,  # Use your config class\n    rollout_is_weights: torch.Tensor | None = None,  # Keep signature consistent, but unused in this function\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Compute the CPGD policy objective by directly clipping log_prob.\n    
This function replicates the logic from the original siirl 'if use_cpgd_loss:' block.\n\n    Args:\n        old_log_prob: Log-probabilities under the old policy.\n        log_prob: Log-probabilities under the current policy.\n        advantages: Advantage estimates.\n        response_mask: Mask for valid tokens.\n        loss_agg_mode: Aggregation mode for the loss.\n        config: Configuration object containing clip ratios.\n        rollout_is_weights: Not used in this specific CPGD implementation.\n\n    Returns:\n        pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower\n    \"\"\"\n    assert config is not None, \"Config must be provided for CPGD loss\"\n\n    # --- Extract clip parameters from config ---\n    clip_ratio = config.clip_ratio\n    clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_ratio\n    clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_ratio\n    clip_ratio_c = (\n        config.clip_ratio_c if config.clip_ratio_c is not None else 3.0\n    )  # Needed only for pg_clipfrac_lower metric\n\n    assert clip_ratio_c > 1.0, f\"clip_ratio_c ({clip_ratio_c}) must be > 1.0\"\n\n    negative_approx_kl = log_prob - old_log_prob\n    ratio = torch.exp(negative_approx_kl)\n    ppo_kl = siirl_F.masked_mean(-negative_approx_kl, response_mask)\n\n    clipped_log_prob = torch.where(\n        advantages > 0,\n        torch.clamp(log_prob, max=math.log(1 + clip_ratio_high) + old_log_prob),\n        torch.clamp(log_prob, min=math.log(1 - clip_ratio_low) + old_log_prob),\n    )\n    pg_losses = -clipped_log_prob * advantages\n    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)\n\n    # Calculate clip fraction based on where the *ratio* would have been clipped\n    is_clipped = torch.where(advantages > 0, ratio > 1 + clip_ratio_high, ratio < 1 - clip_ratio_low)\n    pg_clipfrac = siirl_F.masked_mean(is_clipped.float(), response_mask).detach()\n\n    # Calculate lower clip fraction (dual clip metric)\n    pg_clipfrac_lower = siirl_F.masked_mean((ratio > clip_ratio_c) * (advantages < 0).float(), response_mask).detach()\n\n    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower\n\ndef compute_policy_loss(\n    old_log_prob,\n    log_prob,\n    advantages,\n    response_mask,\n    cliprange=None,\n    cliprange_low=None,\n    cliprange_high=None,\n    clip_ratio_c=3.0,\n    loss_agg_mode: str = \"token-mean\",\n    use_cpgd_loss=False,\n):\n    \"\"\"\n    Compute the clipped policy objective and related metrics for PPO.\n\n    Adapted from\n    https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122\n\n    Args:\n        old_log_prob (torch.Tensor):\n            Log-probabilities of actions under the old policy, shape (batch_size, response_length).\n        log_prob (torch.Tensor):\n            Log-probabilities of actions under the current policy, shape (batch_size, response_length).\n        advantages (torch.Tensor):\n            Advantage estimates for each action, shape (batch_size, response_length).\n        response_mask (torch.Tensor):\n            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).\n        cliprange (float, optional):\n            Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347.\n            Defaults to None (must be provided).\n        cliprange_low (float, optional):\n            Lower clip range for dual-clip PPO. 
Defaults to same as `cliprange`.\n        cliprange_high (float, optional):\n            Upper clip range for dual-clip PPO. Defaults to same as `cliprange`.\n        clip_ratio_c (float, optional):\n            Lower bound of the ratio for dual-clip PPO. See https://arxiv.org/pdf/1912.09729.\n            Defaults to 3.0.\n        loss_agg_mode (str, optional):\n            Aggregation mode for `agg_loss`. Defaults to \"token-mean\".\n        use_cpgd_loss (bool):\n            whter to use the CPGD loss\n    \"\"\"\n    assert clip_ratio_c > 1.0, \"The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0,\" + f\" but get the value: {clip_ratio_c}.\"\n\n    negative_approx_kl = log_prob - old_log_prob\n    ratio = torch.exp(negative_approx_kl)\n    ppo_kl = siirl_F.masked_mean(-negative_approx_kl, response_mask)\n\n    if cliprange_low is None:\n        cliprange_low = cliprange\n    if cliprange_high is None:\n        cliprange_high = cliprange\n\n    if use_cpgd_loss:\n        clipped_log_prob = torch.where(advantages > 0, torch.clamp(log_prob, max=math.log(1 + cliprange_high) + old_log_prob), torch.clamp(log_prob, min=math.log(1 - cliprange_low) + old_log_prob))\n        pg_losses = -clipped_log_prob * advantages\n        pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)  # use token-mean\n\n        is_clipped = torch.where(advantages > 0, ratio > 1 + cliprange_high, ratio < 1 - cliprange_low)\n        pg_clipfrac = siirl_F.masked_mean(is_clipped.float(), response_mask).detach()\n        pg_clipfrac_lower = siirl_F.masked_mean((ratio > clip_ratio_c) * (advantages < 0).float(), response_mask).detach()\n\n        return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower\n\n    pg_losses1 = -advantages * ratio\n    pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)  # - clip(ratio, 1-cliprange, 1+cliprange) * A\n    clip_pg_losses1 = torch.maximum(pg_losses1, pg_losses2)  # max(-ratio * A, -clip(ratio, 1-cliprange, 1+cliprange) * A)\n    pg_clipfrac = siirl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)\n\n    pg_losses3 = -advantages * clip_ratio_c\n    clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)\n    pg_clipfrac_lower = siirl_F.masked_mean(torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), response_mask)\n\n    pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)\n    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)\n\n    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower\n\n\n@register_policy_loss(\"vanilla\")\ndef compute_policy_loss_vanilla(\n    old_log_prob: torch.Tensor,\n    log_prob: torch.Tensor,\n    advantages: torch.Tensor,\n    response_mask: torch.Tensor,\n    loss_agg_mode: str = \"token-mean\",\n    config: Optional[ActorArguments] = None,\n    rollout_is_weights: torch.Tensor | None = None,\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Compute the clipped policy objective and related metrics for PPO.\n\n    Adapted from\n    https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122\n\n    Args:\n        old_log_prob (torch.Tensor):\n            Log-probabilities of actions under the old policy, shape (batch_size, response_length).\n        log_prob (torch.Tensor):\n            Log-probabilities of actions under the current policy, shape (batch_size, response_length).\n        advantages 
(torch.Tensor):\n            Advantage estimates for each action, shape (batch_size, response_length).\n        response_mask (torch.Tensor):\n            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).\n        loss_agg_mode (str, optional):\n            Aggregation mode for `agg_loss`. Defaults to \"token-mean\".\n        config: `(ActorArguments)`:\n            config for the actor.\n        rollout_log_probs: `(torch.Tensor)`:\n            log probabilities of actions under the rollout policy, shape (batch_size, response_length).\n    \"\"\"\n\n    assert config is not None\n    assert not isinstance(config, AlgorithmArguments)\n    clip_ratio = config.clip_ratio  # Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347.\n    clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_ratio\n    clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_ratio\n    clip_ratio_c = config.clip_ratio_c\n\n    cliprange = clip_ratio\n    cliprange_low = clip_ratio_low\n    cliprange_high = clip_ratio_high\n\n    assert clip_ratio_c > 1.0, (\n        \"The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0,\"\n        + f\" but get the value: {clip_ratio_c}.\"\n    )\n\n    negative_approx_kl = log_prob - old_log_prob\n    # Clamp negative_approx_kl for stability\n    negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)\n    ratio = torch.exp(negative_approx_kl)\n    ppo_kl = siirl_F.masked_mean(-negative_approx_kl, response_mask)\n\n    pg_losses1 = -advantages * ratio\n    if cliprange_low is None:\n        cliprange_low = cliprange\n    if cliprange_high is None:\n        cliprange_high = cliprange\n    pg_losses2 = -advantages * torch.clamp(\n        ratio, 1 - cliprange_low, 1 + cliprange_high\n    )  # - clip(ratio, 1-cliprange, 1+cliprange) * A\n    clip_pg_losses1 = torch.maximum(\n        pg_losses1, pg_losses2\n    )  # max(-ratio * A, -clip(ratio, 1-cliprange, 1+cliprange) * A)\n    pg_clipfrac = siirl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)\n\n    pg_losses3 = -advantages * clip_ratio_c\n    clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)\n    pg_clipfrac_lower = siirl_F.masked_mean(\n        torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), response_mask\n    )\n\n    pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)\n\n    # Apply rollout importance sampling weights if provided\n    if rollout_is_weights is not None:\n        pg_losses = pg_losses * rollout_is_weights\n\n    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)\n\n    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower\n\n\n@register_policy_loss(\"gspo\")\ndef compute_policy_loss_gspo(\n    old_log_prob: torch.Tensor,\n    log_prob: torch.Tensor,\n    advantages: torch.Tensor,\n    response_mask: torch.Tensor,\n    loss_agg_mode: str = \"seq-mean-token-mean\",\n    config: Optional[ActorArguments] = None,\n    rollout_is_weights: torch.Tensor | None = None,\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Compute the clipped policy objective and related metrics for GSPO.\n\n    See https://arxiv.org/pdf/2507.18071 for more details.\n\n    Args:\n        old_log_prob (torch.Tensor):\n            Log-probabilities of actions under the old policy, shape (batch_size, response_length).\n        
log_prob (torch.Tensor):\n            Log-probabilities of actions under the current policy, shape (batch_size, response_length).\n        advantages (torch.Tensor):\n            Advantage estimates for each action, shape (batch_size, response_length).\n        response_mask (torch.Tensor):\n            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).\n        loss_agg_mode (str, optional):\n            Aggregation mode for `agg_loss`. For GSPO, it is recommended to use \"seq-mean-token-mean\".\n    \"\"\"\n\n    assert config is not None\n    assert isinstance(config, ActorArguments)\n    clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else config.clip_ratio\n    clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else config.clip_ratio\n\n    negative_approx_kl = log_prob - old_log_prob\n\n    # compute sequence-level importance ratio:\n    # si(θ) = (π_θ(yi|x)/π_θold(yi|x))^(1/|yi|) =\n    # exp [(1/|y_i|) * Σ_t log(π_θ(y_i,t|x,y_i,<t)/π_θold(y_i,t|x,y_i,<t))]\n    seq_lengths = torch.sum(response_mask, dim=-1).clamp(min=1)\n    negative_approx_kl_seq = torch.sum(negative_approx_kl * response_mask, dim=-1) / seq_lengths\n\n    # Combined ratio at token level:\n    # s_i,t(θ) = sg[s_i(θ)] · π_θ(y_i,t|x, y_i,<t) / sg[π_θ(y_i,t|x, y_i,<t)]\n    # In log space: log(s_i,t(θ)) = sg[log(s_i(θ))] + log_prob - sg[log_prob]\n    log_seq_importance_ratio = log_prob - log_prob.detach() + negative_approx_kl_seq.detach().unsqueeze(-1)\n    log_seq_importance_ratio = torch.clamp(log_seq_importance_ratio, max=10.0)  # clamp for numerical stability\n\n    # finally exp() to remove log\n    seq_importance_ratio = torch.exp(log_seq_importance_ratio)\n\n    pg_losses1 = -advantages * seq_importance_ratio\n    pg_losses2 = -advantages * torch.clamp(seq_importance_ratio, 1 - clip_ratio_low, 1 + clip_ratio_high)\n    pg_losses = torch.maximum(pg_losses1, pg_losses2)\n\n    # Apply rollout importance sampling weights if provided\n    if rollout_is_weights is not None:\n        pg_losses = pg_losses * rollout_is_weights\n\n    # for GSPO, we need to aggregate the loss at the sequence level (seq-mean-token-mean)\n    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=\"seq-mean-token-mean\")\n\n    # For compatibility, return zero for pg_clipfrac_lower (not used in standard GSPO)\n    pg_clipfrac = siirl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)\n    pg_clipfrac_lower = torch.tensor(0.0, device=pg_loss.device)\n\n    ppo_kl = siirl_F.masked_mean(-negative_approx_kl, response_mask)\n\n    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower\n\n\n@register_policy_loss(\"gpg\")\ndef compute_policy_loss_gpg(\n    old_log_prob: torch.Tensor,\n    log_prob: torch.Tensor,\n    advantages: torch.Tensor,\n    response_mask: torch.Tensor,\n    loss_agg_mode: str = \"token-mean\",\n    config: Optional[ActorArguments] = None,\n    rollout_is_weights: torch.Tensor | None = None,\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"Adapted from\n    https://github.com/AMAP-ML/GPG/blob/main/VisualThinker-R1-Zero/src/open-r1-multimodal/src/open_r1/trainer/grpo_trainer.py#L495\n    Args:\n        log_prob: `(torch.Tensor)`\n            shape: (bs, response_length)\n        advantages: `(torch.Tensor)`\n            shape: (bs, response_length)\n        response_mask: `(torch.Tensor)`\n            shape: (bs, response_length)\n    return:\n        
pg_loss: `a scalar torch.Tensor`\n            policy gradient loss computed via GPG\n    \"\"\"\n    pg_losses = -log_prob * advantages\n\n    # Apply rollout importance sampling weights if provided\n    if rollout_is_weights is not None:\n        pg_losses = pg_losses * rollout_is_weights\n\n    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)\n    return pg_loss, torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0)\n\n\n@register_policy_loss(\"clip_cov\")\ndef compute_policy_loss_clip_cov(\n    old_log_prob: torch.Tensor,\n    log_prob: torch.Tensor,\n    advantages: torch.Tensor,\n    response_mask: torch.Tensor,\n    loss_agg_mode: str = \"token-mean\",\n    config: Optional[ActorArguments] = None,\n    rollout_is_weights: torch.Tensor | None = None,\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Compute the clipped policy objective and related metrics for Clip-Cov.\n\n    Adapted from\n    https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/verl/trainer/ppo/core_algos.py\n\n    Args:\n        old_log_prob (torch.Tensor):\n            Log-probabilities of actions under the old policy, shape (batch_size, response_length).\n        log_prob (torch.Tensor):\n            Log-probabilities of actions under the current policy, shape (batch_size, response_length).\n        advantages (torch.Tensor):\n            Advantage estimates for each action, shape (batch_size, response_length).\n        response_mask (torch.Tensor):\n            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).\n        cliprange (float, optional):\n            Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347.\n            Defaults to None (must be provided).\n        cliprange_low (float, optional):\n            Lower clip range for dual-clip PPO. Defaults to same as `cliprange`.\n        cliprange_high (float, optional):\n            Upper clip range for dual-clip PPO. Defaults to same as `cliprange`.\n        loss_agg_mode (str, optional):\n            Aggregation mode for `agg_loss`. Defaults to \"token-mean\".\n        clip_cvo_ratio (float, optional):\n            Ratio for clipping the covariance. Defaults to 0.0002.\n        clip_cov_lb (float, optional):\n            Lower bound for clipping covariance. Defaults to 1.0.\n        clip_cov_ub (float, optional):\n            Upper bound for clipping covariance. 
Defaults to 5.0.\n    \"\"\"\n    assert config is not None\n    assert not isinstance(config, ActorArguments), \"passing AlgoConfig not supported yet\"\n    assert config.policy_loss is not None\n\n    clip_cov_ratio = config.policy_loss.clip_cov_ratio if config.policy_loss.clip_cov_ratio is not None else 0.0002\n    cliprange = config.clip_ratio\n    cliprange_low = config.clip_ratio_low if config.clip_ratio_low is not None else cliprange\n    cliprange_high = config.clip_ratio_high if config.clip_ratio_high is not None else cliprange\n    clip_cov_ub = config.policy_loss.clip_cov_ub if config.policy_loss.clip_cov_ub is not None else 5.0\n    clip_cov_lb = config.policy_loss.clip_cov_lb if config.policy_loss.clip_cov_lb is not None else 1.0\n\n    assert clip_cov_ratio > 0, \"clip_ratio should be larger than 0.\"\n\n    negative_approx_kl = log_prob - old_log_prob\n    ratio = torch.exp(negative_approx_kl)\n    ppo_kl = siirl_F.masked_mean(-negative_approx_kl, response_mask)\n\n    pg_losses1 = -advantages * ratio\n\n    if cliprange_low is None:\n        cliprange_low = cliprange\n    if cliprange_high is None:\n        cliprange_high = cliprange\n\n    corr = torch.ones_like(advantages)\n    pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)\n    clip_by_origin = (pg_losses2 > pg_losses1) & (response_mask > 0)\n\n    cov_all = (advantages - siirl_F.masked_mean(advantages, response_mask)) * (\n        log_prob - siirl_F.masked_mean(log_prob.detach(), response_mask)\n    )\n    cov_all[response_mask == 0] = -torch.inf\n    cov_all[clip_by_origin] = -torch.inf\n\n    clip_num = max(int(clip_cov_ratio * response_mask.sum().item()), 1)\n    top_k_idx = (cov_all < clip_cov_ub) & (cov_all > clip_cov_lb) & (response_mask > 0)\n    top_k_idx = torch.nonzero(top_k_idx)\n\n    if len(top_k_idx) > 0:\n        perm = torch.randperm(len(top_k_idx))\n        top_k_idx = top_k_idx[perm[: min(clip_num, len(top_k_idx))]]\n    else:\n        top_k_idx = torch.empty((0, 2), device=cov_all.device, dtype=torch.long)\n\n    corr[top_k_idx[:, 0], top_k_idx[:, 1]] = 0\n\n    pg_clipfrac = siirl_F.masked_mean((corr == 0).float(), response_mask)\n\n    pg_losses = torch.maximum(pg_losses1, pg_losses2) * corr\n\n    # Apply rollout importance sampling weights if provided\n    if rollout_is_weights is not None:\n        pg_losses = pg_losses * rollout_is_weights\n\n    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)\n\n    return pg_loss, pg_clipfrac, ppo_kl, torch.tensor(0.0)\n\n\n@register_policy_loss(\"kl_cov\")\ndef compute_policy_loss_kl_cov(\n    old_log_prob: torch.Tensor,\n    log_prob: torch.Tensor,\n    advantages: torch.Tensor,\n    response_mask: torch.Tensor,\n    loss_agg_mode: str = \"token-mean\",\n    config: Optional[ActorArguments] = None,\n    rollout_is_weights: torch.Tensor | None = None,\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Compute the clipped policy objective and related metrics for Clip-Cov.\n\n    Adapted from\n    https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/verl/trainer/ppo/core_algos.py\n\n    Args:\n        old_log_prob (torch.Tensor):\n            Log-probabilities of actions under the old policy, shape (batch_size, response_length).\n        log_prob (torch.Tensor):\n            Log-probabilities of actions under the current policy, shape (batch_size, response_length).\n        advantages (torch.Tensor):\n            Advantage estimates for 
each action, shape (batch_size, response_length).\n        response_mask (torch.Tensor):\n            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).\n        loss_agg_mode (str, optional):\n            Aggregation mode for `agg_loss`. Defaults to \"token-mean\".\n        kl_cov_ratio (float, optional):\n            Ratio for selecting the top-k covariance values. Defaults to 0.0002.\n        ppo_kl_coef (float, optional):\n            Coefficient for the KL penalty term in the loss. Defaults to 1.\n    \"\"\"\n    assert config is not None\n    assert not isinstance(config, ActorArguments), \"passing AlgoConfig not supported yet\"\n    assert config.policy_loss is not None\n\n    kl_cov_ratio = config.policy_loss.kl_cov_ratio if config.policy_loss.kl_cov_ratio is not None else 0.0002\n    ppo_kl_coef = config.policy_loss.ppo_kl_coef if config.policy_loss.ppo_kl_coef is not None else 1.0\n\n    assert kl_cov_ratio > 0, \"kl_cov_ratio should be larger than 0.\"\n\n    negative_approx_kl = log_prob - old_log_prob\n    abs_kl = negative_approx_kl.abs()\n    ratio = torch.exp(negative_approx_kl)\n    ppo_kl_abs = siirl_F.masked_mean(negative_approx_kl.abs(), response_mask)\n    pg_losses1 = -advantages * ratio\n    pg_losses_kl = -advantages * ratio + ppo_kl_coef * abs_kl\n    pg_losses = pg_losses1\n\n    all_valid = response_mask > 0\n    all_valid_idx = torch.nonzero(all_valid.reshape(-1), as_tuple=True)[0]\n    all_valid_adv = advantages[all_valid].detach().reshape(-1).cpu()\n    all_valid_logp = log_prob[all_valid].detach().reshape(-1).cpu()\n\n    k = min(kl_cov_ratio, len(all_valid_adv))\n\n    if k != 0:\n        cov_lst_all = (all_valid_adv - all_valid_adv.mean()) * (all_valid_logp - all_valid_logp.mean())\n        k_percent_nums = max(1, int(len(cov_lst_all) * kl_cov_ratio))\n        large_cov_idxs = torch.topk(cov_lst_all, k_percent_nums, largest=True).indices\n\n        if len(large_cov_idxs) != 0:\n            large_cov_idxs = all_valid_idx[large_cov_idxs]\n            pg_losses[large_cov_idxs // advantages.shape[1], large_cov_idxs % advantages.shape[1]] = pg_losses_kl[\n                large_cov_idxs // advantages.shape[1], large_cov_idxs % advantages.shape[1]\n            ]\n\n    # Apply rollout importance sampling weights if provided\n    if rollout_is_weights is not None:\n        pg_losses = pg_losses * rollout_is_weights\n\n    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)\n\n    return pg_loss, torch.tensor(0.0), ppo_kl_abs, torch.tensor(0.0)\n\n\n@register_policy_loss(\"geo_mean\")\ndef compute_policy_loss_geo_mean(\n    old_log_prob: torch.Tensor,\n    log_prob: torch.Tensor,\n    advantages: torch.Tensor,\n    response_mask: torch.Tensor,\n    loss_agg_mode: str = \"token-mean\",\n    config: Optional[ActorArguments] = None,\n    rollout_is_weights: torch.Tensor | None = None,\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Compute the clipped policy objective and related metrics for GMPO.\n\n    Adapted from paper https://arxiv.org/abs/2507.20673\n    https://github.com/callsys/GMPO/blob/main/train_zero_math_gmpo.py\n\n    Args:\n        old_log_prob (torch.Tensor):\n            Log-probabilities of actions under the old policy, shape (batch_size, response_length).\n        log_prob (torch.Tensor):\n            Log-probabilities of actions under the current policy, shape (batch_size, response_length).\n        advantages (torch.Tensor):\n       
     Advantage estimates for each action, shape (batch_size, response_length).\n        response_mask (torch.Tensor):\n            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).\n        loss_agg_mode (str, optional):\n            not used\n    \"\"\"\n\n    assert config is not None\n    assert not isinstance(config, ActorArguments)\n    clip_ratio = config.clip_ratio  # Clipping parameter. See https://arxiv.org/abs/1707.06347.\n    clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_ratio\n    clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_ratio\n\n    cliprange = clip_ratio\n    cliprange_low = clip_ratio_low\n    cliprange_high = clip_ratio_high\n    if cliprange_low is None:\n        cliprange_low = cliprange\n    if cliprange_high is None:\n        cliprange_high = cliprange\n\n    negative_approx_kl = log_prob - old_log_prob\n    # Clamp negative_approx_kl for stability (uncomment it if you like)\n    # negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)\n    ppo_kl = siirl_F.masked_mean(-negative_approx_kl, response_mask)\n\n    # Clipping at token-level & Clipping wider\n    sgn_advantage = torch.sign(advantages)\n    negative_approx_kl_clamp = torch.clamp(negative_approx_kl, -cliprange_low, cliprange_high)\n    negative_approx_kl_min = torch.min(sgn_advantage * negative_approx_kl, sgn_advantage * negative_approx_kl_clamp)\n    negative_approx_kl_min = sgn_advantage * negative_approx_kl_min\n\n    # Geometric-Mean Policy Optimization\n    response_mask_sum = response_mask.sum(dim=-1)\n    ratio = torch.exp((negative_approx_kl_min * response_mask).sum(dim=-1) / (response_mask_sum + 1e-8))\n    # we only support sequence level advantage for now,\n    # otherwise, below would be not consistent with the paper\n    advantage = (advantages * response_mask).sum(dim=-1) / (response_mask_sum + 1e-8)\n    pg_losses = -advantage * ratio\n\n    # Apply rollout importance sampling weights if provided\n    # For geo_mean, IS weights are 2D (batch_size, seq_length) and need to be aggregated to sequence level\n    if rollout_is_weights is not None:\n        # Aggregate token-level weights to sequence level using geometric mean for consistency\n        # Note: rollout_is_weights is always 2D regardless of rollout_is_level\n        seq_is_weights = torch.exp(\n            (torch.log(rollout_is_weights + 1e-10) * response_mask).sum(dim=-1) / (response_mask_sum + 1e-8)\n        )\n        pg_losses = pg_losses * seq_is_weights\n\n    pg_loss = torch.mean(pg_losses)\n\n    # higher: ratio is too large that need clamp to clip_high (when adv > 0)\n    clipped = torch.ne(negative_approx_kl, negative_approx_kl_clamp)\n    pg_clipfrac = siirl_F.masked_mean((clipped * (advantages > 0)).float(), response_mask)\n    pg_clipfrac_lower = siirl_F.masked_mean((clipped * (advantages < 0)).float(), response_mask)\n\n    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower\n\n\ndef compute_entropy_loss(logits, response_mask, loss_agg_mode: str = \"token-mean\"):\n    \"\"\"Compute categorical entropy loss (For backward compatibility)\n\n    Args:\n        logits (torch.Tensor): shape is (bs, response_length, vocab_size)\n        response_mask (torch.Tensor): shape is (bs, response_length)\n\n    Returns:\n        entropy: a scalar torch.Tensor\n\n    \"\"\"\n    # compute entropy\n    token_entropy = siirl_F.entropy_from_logits(logits)  # (bs, response_len)\n    entropy_loss 
= agg_loss(loss_mat=token_entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)\n    return entropy_loss\n\n\ndef compute_value_loss(\n    vpreds: torch.Tensor,\n    returns: torch.Tensor,\n    values: torch.Tensor,\n    response_mask: torch.Tensor,\n    cliprange_value: float,\n    loss_agg_mode: str = \"token-mean\",\n):\n    \"\"\"\n    Compute the clipped value-function loss for PPO.\n\n    Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1151\n\n    Args:\n        vpreds (torch.FloatTensor):\n            Predicted values from the value head, shape (batch_size, response_length).\n        values (torch.FloatTensor):\n            Old (baseline) values from the value head, shape (batch_size, response_length).\n        returns (torch.FloatTensor):\n            Ground-truth returns, shape (batch_size, response_length).\n        response_mask (torch.Tensor):\n            Mask indicating which tokens to include in the value loss calculation.\n        cliprange_value (float):\n            Clip range for value prediction updates.\n        loss_agg_mode (str, optional):\n            Aggregation mode for `agg_loss`. Defaults to \"token-mean\".\n\n    Returns:\n        vf_loss (torch.FloatTensor):\n            A scalar tensor containing the aggregated value-function loss.\n        vf_clipfrac (float):\n            Fraction of elements where the clipped loss was used.\n    \"\"\"\n    vpredclipped = siirl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value)\n    vf_losses1 = (vpreds - returns) ** 2\n    vf_losses2 = (vpredclipped - returns) ** 2\n    clipped_vf_losses = torch.max(vf_losses1, vf_losses2)\n    vf_loss = 0.5 * agg_loss(loss_mat=clipped_vf_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)\n    vf_clipfrac = siirl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), response_mask)\n    return vf_loss, vf_clipfrac\n\n\ndef kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor:\n    \"\"\"Compute KL divergence given logprob and ref_logprob. Optionally using straight through to bind k2 on other\n    kl penalty compute method for unbiased KL gradient estimation.\n    See more description in http://joschu.net/blog/kl-approx.html\n\n    Args:\n        logprob:\n        ref_logprob:\n\n    Returns:\n        kl_estimate\n    \"\"\"\n    forward_score = kl_penalty_forward(logprob, ref_logprob, kl_penalty)\n    if not kl_penalty.endswith(\"+\") or kl_penalty in (\"mse\", \"k2\"):\n        return forward_score\n\n    \"\"\"\n    The expectation of k1 and k3 estimator is the expectaed value of KL, but the expected gradient of k1 and k3\n    estimator is not the expectaed gradient of KL. 
On the other hand k2 estimator gives right gradient estimator,\n    so we use a straight through trick here if the kl_penalty method ends with '+', .e.g., k3+.\n    \"\"\"\n    backward_score = 0.5 * (logprob - ref_logprob).square()\n\n    return backward_score - backward_score.detach() + forward_score.detach()\n\n\ndef kl_penalty_forward(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor:\n    \"\"\"Compute KL divergence given logprob and ref_logprob.\n    Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1104\n    See more description in http://joschu.net/blog/kl-approx.html\n\n    Args:\n        logprob:\n        ref_logprob:\n\n    Returns:\n        kl_estimate\n    \"\"\"\n    if kl_penalty in (\"kl\", \"k1\"):\n        return logprob - ref_logprob\n\n    if kl_penalty == \"abs\":\n        return (logprob - ref_logprob).abs()\n\n    if kl_penalty in (\"mse\", \"k2\"):\n        return 0.5 * (logprob - ref_logprob).square()\n\n    # J. Schulman. Approximating kl divergence, 2020.\n    # # URL http://joschu.net/blog/kl-approx.html.\n    if kl_penalty in (\"low_var_kl\", \"k3\"):\n        kl = ref_logprob - logprob\n        # For numerical stability\n        kl = torch.clamp(kl, min=-20, max=20)\n        ratio = torch.exp(kl)\n        kld = (ratio - kl - 1).contiguous()\n        return torch.clamp(kld, min=-10, max=10)\n\n    if kl_penalty == \"full\":\n        # so, here logprob and ref_logprob should contain the logits for every token in vocabulary\n        raise NotImplementedError\n\n    raise NotImplementedError\n\n\ndef compute_pf_ppo_reweight_data(\n    data,\n    reweight_method: str = \"pow\",\n    weight_pow: float = 2.0,\n):\n    \"\"\"Reweight the data based on the token_level_scores.\n\n    Args:\n        data: TensorDict object, containing batch, non_tensor_batch and meta_info\n        reweight_method: str, choices: \"pow\", \"max_min\", \"max_random\"\n        weight_pow: float, the power of the weight\n\n    Returns:\n\n    \"\"\"\n\n    @torch.no_grad()\n    def compute_weights(scores: torch.Tensor, reweight_method: str, weight_pow: float) -> torch.Tensor:\n        \"\"\"Compute importance weights for resampling based on scores.\n\n        Args:\n            scores (torch.Tensor): Tensor of scores to compute weights from.\n            reweight_method (str): Method for computing weights ('pow', 'max_min', 'max_random').\n            weight_pow (float): Power exponent for 'pow' method.\n\n        Returns:\n            torch.Tensor: Computed importance weights.\n\n        Raises:\n            ValueError: If reweight_method is not supported.\n        \"\"\"\n        if reweight_method == \"pow\":\n            weights = torch.pow(torch.abs(scores), weight_pow)\n        elif reweight_method == \"max_min\":\n            max_score = torch.max(scores)\n            min_score = torch.min(scores)\n            weights = torch.where((scores == max_score) | (scores == min_score), 1.0, 0.0)\n        elif reweight_method == \"max_random\":\n            max_score = torch.max(scores)\n            weights = torch.where(scores == max_score, 0.4, 0.1)\n        else:\n            raise ValueError(f\"Unsupported reweight_method: {reweight_method}\")\n        return weights\n\n    scores = data.batch[\"token_level_scores\"].sum(dim=-1)\n    weights = compute_weights(scores, reweight_method, weight_pow)\n    weights = torch.clamp(weights + 1e-8, min=1e-8)\n\n    batch_size = scores.shape[0]\n    sample_indices = 
torch.multinomial(weights, batch_size, replacement=True)\n\n    resampled_batch = {key: tensor[sample_indices] for key, tensor in data.batch.items()}\n\n    sample_indices_np = sample_indices.numpy()\n    resampled_non_tensor_batch = {}\n    for key, array in data.non_tensor_batch.items():\n        if isinstance(array, np.ndarray):\n            resampled_non_tensor_batch[key] = array[sample_indices_np]\n        else:\n            resampled_non_tensor_batch[key] = [array[i] for i in sample_indices_np]\n\n    resampled_meta_info = {}\n    for key, value in data.meta_info.items():\n        if isinstance(value, list) and len(value) == batch_size:\n            resampled_meta_info[key] = [value[i] for i in sample_indices_np]\n        else:\n            resampled_meta_info[key] = value\n\n    from copy import deepcopy\n\n    resampled_data = deepcopy(data)\n    resampled_data.batch = type(data.batch)(resampled_batch)\n    resampled_data.batch.batch_size = data.batch.batch_size\n    resampled_data.non_tensor_batch = resampled_non_tensor_batch\n    resampled_data.meta_info = resampled_meta_info\n\n    return resampled_data\n\n\ndef apply_kl_penalty(data: TensorDict, kl_ctrl: AdaptiveKLController, kl_penalty=\"kl\", multi_turn=False):\n    \"\"\"Apply KL penalty to the token-level rewards.\n\n    This function computes the KL divergence between the reference policy and current policy,\n    then applies a penalty to the token-level rewards based on this divergence.\n\n    Args:\n        data (TensorDict): The data containing batched model outputs and inputs.\n        kl_ctrl (core_algos.AdaptiveKLController): Controller for adaptive KL penalty.\n        kl_penalty (str, optional): Type of KL penalty to apply. Defaults to \"kl\".\n        multi_turn (bool, optional): Whether the data is from a multi-turn conversation. 
Defaults to False.\n\n    Returns:\n        tuple: A tuple containing:\n            - The updated data with token-level rewards adjusted by KL penalty\n            - A dictionary of metrics related to the KL penalty\n    \"\"\"\n    responses = data[\"responses\"]\n    token_level_scores = data[\"token_level_scores\"]\n    batch_size = data.batch_size[0]\n    response_mask = data[\"response_mask\"]\n\n    # compute kl between ref_policy and current policy\n    # When apply_kl_penalty, algorithm.use_kl_in_reward=True, so the reference model has been enabled.\n    kld = kl_penalty(data[\"old_log_probs\"], data[\"ref_log_prob\"], kl_penalty=kl_penalty)  # (batch_size, response_length)\n    kld = kld * response_mask\n    beta = kl_ctrl.value\n\n    token_level_rewards = token_level_scores - beta * kld\n\n    current_kl = siirl_F.masked_mean(kld, mask=response_mask, axis=-1)  # average over sequence\n    current_kl = torch.mean(current_kl, dim=0).item()\n\n    # according to https://github.com/huggingface/trl/blob/951ca1841f29114b969b57b26c7d3e80a39f75a0/trl/trainer/ppo_trainer.py#L837\n    kl_ctrl.update(current_kl=current_kl, n_steps=batch_size)\n    data.batch[\"token_level_rewards\"] = token_level_rewards\n\n    metrics = {\"actor/reward_kl_penalty\": current_kl, \"actor/reward_kl_penalty_coeff\": beta}\n\n    return data, metrics\n\n\ndef compute_advantage(data: TensorDict, adv_estimator, gamma=1.0, lam=1.0, norm_adv_by_std_in_grpo=True, weight_factor_in_cpgd=\"STD_weight\", **kwargs):\n    \"\"\"Compute advantage estimates for policy optimization.\n\n    This function computes advantage estimates using various estimators like GAE, GRPO, REINFORCE++, CPGD, etc.\n    The advantage estimates are used to guide policy optimization in RL algorithms.\n\n    Args:\n        data (TensorDict): The data containing batched model outputs and inputs.\n        adv_estimator: The advantage estimator to use (e.g., GAE, GRPO, REINFORCE++, CPGD).\n        gamma (float, optional): Discount factor for future rewards. Defaults to 1.0.\n        lam (float, optional): Lambda parameter for GAE. Defaults to 1.0.\n        num_repeat (int, optional): Number of times to repeat the computation. Defaults to 1.\n        multi_turn (bool, optional): Whether the data is from a multi-turn conversation. Defaults to False.\n        norm_adv_by_std_in_grpo (bool, optional): Whether to normalize advantages by standard deviation in GRPO. Defaults to True.\n        weight_factor_in_cpgd (str, optional): whether to use the STD weight as GRPO or clip_filter_like_weight. 
choices: {STD_weight, clip_filter_like_weight, naive}\n\n    Returns:\n        TensorDict: The updated data with computed advantages and returns.\n    \"\"\"\n    # Back-compatible with trainers that do not compute response mask in fit\n    if \"response_mask\" not in data.keys():\n        data.batch[\"response_mask\"] = compute_response_mask(data)\n    # prepare response group\n    # TODO: add other ways to estimate advantages\n    if adv_estimator == AdvantageEstimator.GAE:\n        advantages, returns = compute_gae_advantage_return(\n            token_level_rewards=data[\"token_level_rewards\"],\n            values=data[\"values\"],\n            response_mask=data[\"response_mask\"],\n            gamma=gamma,\n            lam=lam,\n        )\n        data[\"advantages\"] = advantages\n        data[\"returns\"] = returns\n        if kwargs.get(\"use_pf_ppo\", False):\n            data = compute_pf_ppo_reweight_data(\n                data,\n                kwargs.get(\"pf_ppo_reweight_method\", \"pow\"),\n                kwargs.get(\"pf_ppo_weight_pow\", 2.0),\n            )\n    elif adv_estimator == AdvantageEstimator.GRPO:\n        if \"finish_step\" in data and data[\"responses\"].ndim == 3:\n            # Embodied scenario: compute mask based on finish_step\n            responses = data[\"responses\"]\n            batch_size = responses.size(0)\n            response_length = responses.size(1) * responses.size(2)  # traj_len * action_token_len\n\n            # Get action_token_len from config or infer from responses shape\n            action_token_len = responses.size(2)  # action token length\n            finish_step = data['finish_step'] * action_token_len\n\n            steps = torch.arange(response_length, device=responses.device)\n            steps_expanded = steps.unsqueeze(0).expand(batch_size, -1)\n            grpo_calculation_mask = steps_expanded < finish_step.unsqueeze(1)  # (batch_size, traj_len)\n\n            logger.info(\"[GRPO] Using finish_step-based mask for embodied scenario\")\n        else:\n            # NLP scenario or no finish_step: use attention_mask-based response_mask\n            grpo_calculation_mask = data[\"response_mask\"]\n            logger.info(\"[GRPO] Using attention_mask-based response_mask for NLP scenario\")\n        # Call compute_grpo_outcome_advantage with parameters matching its definition\n        advantages, returns = compute_grpo_outcome_advantage(\n            token_level_rewards=data[\"token_level_rewards\"],\n            response_mask=grpo_calculation_mask,\n            index=data[\"uid\"],\n            norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,\n        )\n        data[\"advantages\"] = advantages\n        data[\"returns\"] = returns\n        # Store the mask for consistent metrics calculation\n        data[\"response_mask\"] = grpo_calculation_mask\n        logger.debug(\"[GRPO] Stored response_mask in batch for consistent metrics\")\n    elif adv_estimator == AdvantageEstimator.CPGD:\n        cpgd_calculation_mask = data[\"response_mask\"]\n        # Call compute_cpgd_outcome_advantage with parameters matching its definition\n        advantages, returns = compute_cpgd_outcome_advantage(\n            token_level_rewards=data[\"token_level_rewards\"],\n            response_mask=cpgd_calculation_mask,\n            index=data[\"uid\"],\n            weight_factor_in_cpgd=weight_factor_in_cpgd,\n        )\n        data[\"advantages\"] = advantages\n        data[\"returns\"] = returns\n    elif 
adv_estimator == AdvantageEstimator.GAE_MARFT:\n        compute_marft_gae_advantage_return(\n            data,\n            pre_agent_group_ids=kwargs[\"agent_group_ids\"],\n            gamma=gamma,\n            lam=lam,\n        )\n    else:\n        raise NotImplementedError\n    return data\n"
  },
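A minimal, illustrative usage sketch for the GRPO outcome-advantage estimator defined in the file above; it is not part of the repository. The sample tensors, the group ids in `index`, and the assumption that `compute_grpo_outcome_advantage` is already imported into scope are all invented for the example.

```python
import numpy as np
import torch

# Two prompts ("uid" groups) with two sampled responses each; response_length = 4.
# GRPO reduces each row to a single outcome score by summing over tokens.
token_level_rewards = torch.tensor(
    [
        [0.0, 0.0, 0.0, 1.0],  # prompt A, response 1
        [0.0, 0.0, 0.0, 0.0],  # prompt A, response 2
        [0.0, 0.0, 1.0, 0.0],  # prompt B, response 1
        [0.0, 0.0, 0.0, 0.0],  # prompt B, response 2
    ]
)
response_mask = torch.ones_like(token_level_rewards)
index = np.array(["A", "A", "B", "B"])  # group id (uid) per sample

advantages, returns = compute_grpo_outcome_advantage(
    token_level_rewards=token_level_rewards,
    response_mask=response_mask,
    index=index,
    norm_adv_by_std_in_grpo=True,  # set False for the Dr.GRPO-style variant
)
# Each row now holds (score - group_mean) / (group_std + epsilon),
# broadcast over the tokens selected by response_mask.
```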
  {
    "path": "siirl/dag_worker/dag_utils.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nUtility functions for DAG worker operations.\n\"\"\"\n\nimport os\nimport ray\nimport torch\nimport inspect\nimport json\nimport time\nimport csv\nimport hashlib\nimport numpy as np\nimport torch.distributed as dist\nfrom contextlib import contextmanager\nfrom datetime import datetime\nfrom zoneinfo import ZoneInfo\nfrom pathlib import Path\nfrom collections import deque\nfrom tensordict import TensorDict\nfrom typing import Dict, Optional, Type, List, Any, Tuple, Union\nfrom loguru import logger\nfrom tensordict import TensorDict\n\nfrom siirl.execution.dag.node import Node, NodeType, NodeRole\nfrom siirl.execution.dag import TaskGraph\nfrom siirl.utils.extras.device import get_device_name, device_synchronize\nfrom siirl.engine.base_worker import Worker\nfrom siirl.utils.import_string import import_string\nfrom siirl.dag_worker.constants import DAGConstants\nfrom siirl.dag_worker.data_structures import ValidationResult\nfrom siirl.dag_worker.metric_aggregator import (\n    DistributedMetricAggregator,\n    _ReduceOp\n)\n\n\n# ==========================================================================================\n# Section 1: Performance & Timing\n# ==========================================================================================\n\n@contextmanager\ndef timer(enable_perf: bool, name: str, timing_dict: dict):\n    \"\"\"Measures execution time of a code block and stores in timing_dict.\"\"\"\n    if enable_perf:\n        device_synchronize()\n    start_time = time.perf_counter()\n    yield\n    if enable_perf:\n        device_synchronize()\n    end_time = time.perf_counter()\n    timing_dict[name] = timing_dict.get(name, 0) + end_time - start_time\n\n\ndef add_prefix_to_dataproto(tensordict: TensorDict, node: Node):\n    \"\"\"\n    Adds a prefix to all keys in the TensorDict.\n    The prefix is formatted as f\"agent_group_{node.agent_group}_\".\n    Only keys that do not already have a prefix will be modified.\n\n    Args:\n        data_proto (TensorDict): The TensorDict instance.\n        node (Node): The node containing the agent_group.\n    \"\"\"\n    prefix = f\"agent_group_{node.agent_group}_\"\n    prefix_agent_group = \"agent_group_\"\n\n    # Process tensor batch\n    if tensordict is not None:\n        new_batch = {}\n        for key, value in tensordict.items():\n            if not key.startswith(prefix_agent_group):\n                new_key = prefix + key\n                new_batch[new_key] = value\n            else:\n                new_batch[key] = value\n        tensordict = TensorDict(new_batch, batch_size=tensordict.batch_size)\n    return tensordict\n\n\ndef remove_prefix_from_dataproto(tensordict, node: Node):\n    \"\"\"\n    Removes the prefix from all keys in the TensorDict.\n    Only keys with a matching prefix will have the prefix removed.\n\n    Args:\n        data_proto (TensorDict): The TensorDict 
instance.\n        node (Node): The node containing the agent_group to identify the prefix.\n    \"\"\"\n    prefix = f\"agent_group_{node.agent_group}_\"\n    prefix_len = len(prefix)\n\n    # Process tensor batch\n    if tensordict is not None:\n        new_batch = {}\n        for key, value in tensordict.items():\n            if key.startswith(prefix):\n                new_key = key[prefix_len:]\n                new_batch[new_key] = value\n            else:\n                new_batch[key] = value\n        tensordict = TensorDict(new_batch, batch_size=tensordict.batch_size)\n\n    return tensordict\n\n\ndef add_prefix_to_metrics(metrics: dict, node: Node) -> dict:\n    \"\"\"Adds agent prefix to all metric keys for multi-agent isolation.\"\"\"\n    prefix = f\"agent_{node.agent_group}_\"\n    prefix_agent_group = \"agent_\"\n    if metrics:\n        new_metrics = {}\n        for key, value in metrics.items():\n            if not key.startswith(prefix_agent_group):\n                new_key = prefix + key\n                new_metrics[new_key] = value\n            else:\n                new_metrics[key] = value\n        metrics = new_metrics\n    return metrics\n\n\n# ==========================================================================================\n# Section 3: Initialization & Setup\n# ==========================================================================================\n\ndef get_and_validate_rank() -> int:\n    \"\"\"Retrieves and validates worker rank from RANK environment variable.\"\"\"\n    rank_str = os.environ.get(\"RANK\")\n    if rank_str is None:\n        raise ValueError(\"Environment variable 'RANK' is not set. This is required for distributed setup.\")\n    try:\n        return int(rank_str)\n    except ValueError as e:\n        raise ValueError(f\"Invalid RANK format: '{rank_str}'. Must be an integer.\") from e\n\n\ndef get_taskgraph_for_rank(rank: int, taskgraph_mapping: Dict[int, TaskGraph]) -> TaskGraph:\n    \"\"\"Retrieves TaskGraph for current rank from mapping.\"\"\"\n    if rank not in taskgraph_mapping:\n        raise ValueError(f\"Rank {rank} not found in the provided taskgraph_mapping.\")\n    taskgraph = taskgraph_mapping[rank]\n\n    if not isinstance(taskgraph, TaskGraph):\n        raise TypeError(f\"Object for rank {rank} must be a TaskGraph, but got {type(taskgraph).__name__}.\")\n    logger.info(f\"Rank {rank} assigned to TaskGraph with ID {taskgraph.graph_id}.\")\n    return taskgraph\n\n\ndef log_ray_actor_info(rank: int):\n    \"\"\"Logs Ray actor context information for debugging.\"\"\"\n    try:\n        ctx = ray.get_runtime_context()\n        logger.debug(\n            f\"Ray Actor Context for Rank {rank}: ActorID={ctx.get_actor_id()}, JobID={ctx.get_job_id()}, \"\n            f\"NodeID={ctx.get_node_id()}, PID={os.getpid()}\"\n        )\n    except RuntimeError:\n        logger.warning(f\"Rank {rank}: Not running in a Ray actor context.\")\n\n\ndef log_role_worker_mapping(role_worker_mapping: Dict[NodeRole, Type[Worker]]):\n    \"\"\"Logs role-to-worker class mapping for verification.\"\"\"\n    if not role_worker_mapping:\n        logger.error(\"Role-to-worker mapping is empty after setup. 
This will cause execution failure.\")\n        return\n\n    logger.debug(\"--- [Role -> Worker Class] Mapping ---\")\n    max_len = max((len(r.name) for r in role_worker_mapping.keys()), default=0)\n    for role, worker_cls in sorted(role_worker_mapping.items(), key=lambda item: item[0].name):\n        logger.debug(\n            f\"  {role.name:<{max_len}} => {worker_cls.__name__} (from {inspect.getmodule(worker_cls).__name__})\"\n        )\n    logger.debug(\"--------------------------------------\")\n\n\n# ==========================================================================================\n# Section 4: Worker Management\n# ==========================================================================================\n\ndef find_first_non_compute_ancestor(taskgraph: TaskGraph, start_node_id: str) -> Optional[Node]:\n    \"\"\"Finds first ancestor node that is not COMPUTE type using BFS.\"\"\"\n    start_node = taskgraph.get_node(start_node_id)\n    if not start_node:\n        logger.warning(f\"Could not find start node '{start_node_id}' in the graph.\")\n        return None\n\n    if start_node.node_type != NodeType.COMPUTE:\n        return start_node\n\n    queue = deque(start_node.dependencies)\n    visited = set(start_node.dependencies)\n\n    while queue:\n        node_id = queue.popleft()\n        logger.debug(f\"Checking dependency node with ID '{node_id}' during upward search.\")\n        node = taskgraph.get_node(node_id)\n\n        if not node:\n            logger.warning(f\"Could not find dependency node with ID '{node_id}' during upward search.\")\n            continue\n\n        if node.node_type != NodeType.COMPUTE:\n            return node\n\n        for dep_id in node.dependencies:\n            if dep_id not in visited:\n                visited.add(dep_id)\n                queue.append(dep_id)\n    return None\n\n\ndef should_create_worker(role_worker_mapping: Dict[NodeRole, Type[Worker]], node: Node) -> bool:\n    \"\"\"Determines if worker instance should be created for a given node.\"\"\"\n    if node.agent_options and node.agent_options.share_instance:\n        # Worker already initialized in target agent node\n        return False\n    return node.node_type in [NodeType.MODEL_TRAIN, NodeType.MODEL_INFERENCE] and node.node_role in role_worker_mapping\n\n\ndef generate_node_worker_key(node: Node) -> str:\n    \"\"\"Generates unique key for node's worker instance.\"\"\"\n    return f\"{node.agent_group}_{node.node_type.value}_{node.node_role.value}\"\n\n\ndef setup_sharding_manager(\n    config,\n    agent_group_process_group: Dict,\n    agent_group: int,\n    worker_dict: Dict[NodeRole, Worker]\n):\n    \"\"\"Configures sharding manager to sync weights between training and inference backends.\"\"\"\n    actor_worker = worker_dict[NodeRole.ACTOR]\n    rollout_worker = worker_dict[NodeRole.ROLLOUT]\n    rollout_pg = agent_group_process_group[agent_group][NodeRole.ROLLOUT]\n\n    if config.actor_rollout_ref.model.model_type == \"embodied\":\n        if hasattr(actor_worker, \"actor_module_fsdp\"):\n            rollout_worker.rollout.model = actor_worker.actor_module_fsdp\n            logger.info(f\"[Embodied] Set module for EmbodiedHFRollout for agent group {agent_group}.\")\n        else:\n            logger.error(f\"[Embodied] Actor worker for agent group {agent_group} does not have 'actor_module_fsdp'.\")\n\n    parallel_config = {\n        
\"rollout_parallel_size\": rollout_worker.config.rollout.tensor_model_parallel_size,\n        \"rollout_world_size\": dist.get_world_size(rollout_pg),\n        \"rollout_rank\": dist.get_rank(rollout_pg),\n    }\n\n    device_name = get_device_name()\n    layer_name_mapping = {\n        \"qkv_layer_name\": \"self_attention.linear_qkv.\",\n        \"gate_proj_layer_name\": \"linear_fc1.weight\",\n    }\n\n    # Lazy import and deferred execution mapping\n    sharding_manager_map = {\n        (\"fsdp\", \"hf\"): (\n                \"siirl.engine.sharding_manager.fsdp_hf.FSDPHFShardingManager\",\n                lambda: {\n                    \"module\": actor_worker.actor_module_fsdp,\n                    \"rollout\": rollout_worker.rollout,\n                    \"offload_param\": getattr(actor_worker, \"_is_offload_param\", False),\n                    \"offload_embedding\": (\n                        getattr(rollout_worker.config, \"embodied\", None) is not None and\n                        getattr(rollout_worker.config.embodied, \"embedding_model_offload\", False)),\n            },\n        ),\n        (\"fsdp\", \"vllm\"): (\n            \"siirl.engine.sharding_manager.fsdp_vllm.MultiAgentFSDPVLLMShardingManager\",\n            lambda: {\n                \"module\": actor_worker.actor_module_fsdp,\n                \"inference_engine\": rollout_worker.rollout.inference_engine,\n                \"model_config\": actor_worker.actor_model_config,\n                \"parallel_config\": parallel_config,\n                \"full_params\": \"hf\" in rollout_worker.config.rollout.load_format,\n                \"offload_param\": getattr(actor_worker, \"_is_offload_param\", False),\n            },\n        ),\n        (\"fsdp\", \"sglang\"): (\n            \"siirl.engine.sharding_manager.fsdp_sglang.MultiAgentFSDPSGLangShardingManager\",\n            lambda: {\n                \"module\": actor_worker.actor_module_fsdp,\n                \"inference_engine\": rollout_worker.rollout.inference_engine,\n                \"model_config\": actor_worker.actor_model_config,\n                \"device_mesh\": torch.distributed.init_device_mesh(\n                    device_name,\n                    mesh_shape=(\n                        parallel_config.get(\"rollout_world_size\") // parallel_config.get(\"rollout_parallel_size\"),\n                        parallel_config.get(\"rollout_parallel_size\"),\n                    ),\n                    mesh_dim_names=[\"dp\", \"infer_tp\"],\n                ),\n                \"rollout_config\": rollout_worker.config.rollout,\n                \"full_params\": \"hf\" in rollout_worker.config.rollout.load_format,\n                \"offload_param\": getattr(actor_worker, \"_is_offload_param\", False),\n                \"multi_stage_wake_up\": rollout_worker.config.rollout.multi_stage_wake_up,\n            },\n        ),\n        (\"megatron\", \"vllm\"): (\n            \"siirl.engine.sharding_manager.megatron_vllm.MultiAgentMegatronVLLMShardingManager\",\n            lambda: {\n                \"actor_module\": actor_worker.actor_module,\n                \"inference_engine\": rollout_worker.rollout.inference_engine,\n                \"model_config\": actor_worker.actor_model_config,\n                \"rollout_config\": rollout_worker.config.rollout,\n                \"transformer_config\": actor_worker.tf_config,\n                \"layer_name_mapping\": layer_name_mapping,\n                \"weight_converter\": get_mcore_weight_converter(actor_worker.actor_model_config, 
actor_worker.dtype),\n                \"device_mesh\": rollout_worker.device_mesh,\n                \"offload_param\": actor_worker._is_offload_param,\n                \"bridge\": actor_worker.bridge,\n            },\n        ),\n        (\"megatron\", \"sglang\"): (\n            \"siirl.engine.sharding_manager.megatron_sglang.MultiAgentMegatronSGLangShardingManager\",\n            lambda: {\n                \"actor_module\": actor_worker.actor_module,\n                \"inference_engine\": rollout_worker.rollout.inference_engine,\n                \"model_config\": actor_worker.actor_model_config,\n                \"rollout_config\": rollout_worker.config.rollout,\n                \"transformer_config\": actor_worker.tf_config,\n                \"layer_name_mapping\": layer_name_mapping,\n                \"weight_converter\": get_mcore_weight_converter(actor_worker.actor_model_config, actor_worker.dtype),\n                \"device_mesh\": torch.distributed.init_device_mesh(\n                    device_name,\n                    mesh_shape=(\n                        parallel_config.get(\"rollout_world_size\") // parallel_config.get(\"rollout_parallel_size\"),\n                        parallel_config.get(\"rollout_parallel_size\"),\n                    ),\n                    mesh_dim_names=[\"dp\", \"infer_tp\"],\n                ),\n                \"offload_param\": getattr(actor_worker, \"_is_offload_param\", False),\n                \"bridge\": actor_worker.bridge,\n            },\n        ),\n    }\n\n    strategy = actor_worker.config.actor.strategy.lower()\n    if strategy == DAGConstants.MEGATRON_STRATEGY:\n        from siirl.models.mcore import get_mcore_weight_converter\n    rollout_name = config.actor_rollout_ref.rollout.name.lower()\n    if (strategy, rollout_name) not in sharding_manager_map:\n        raise NotImplementedError(f\"Unsupported sharding manager configuration: {strategy=}, {rollout_name=}\")\n\n    sharding_manager_cls_str, kwargs_builder = sharding_manager_map[(strategy, rollout_name)]\n    sharding_manager_cls = import_string(sharding_manager_cls_str)\n    sharding_manager = sharding_manager_cls(**kwargs_builder())\n    rollout_worker.set_rollout_sharding_manager(sharding_manager)\n    logger.debug(f\"Set up {sharding_manager_cls.__name__}  for agent group {agent_group}.\")\n\n\ndef get_worker_classes(config, strategy: str) -> Dict[NodeRole, Type[Worker]]:\n    \"\"\"Dynamically imports worker classes based on specified training strategy.\"\"\"\n    if strategy in DAGConstants.FSDP_STRATEGIES:\n        from siirl.engine.fsdp_workers import (\n            ActorRolloutRefWorker,\n            AsyncActorRolloutRefWorker,\n            CriticWorker,\n            RewardModelWorker,\n        )\n\n        actor_cls = (\n            AsyncActorRolloutRefWorker\n            if config.actor_rollout_ref.rollout.mode == \"async\"\n            else ActorRolloutRefWorker\n        )\n        return {\n            NodeRole.ACTOR: actor_cls,\n            NodeRole.ROLLOUT: actor_cls,\n            NodeRole.REFERENCE: actor_cls,\n            NodeRole.CRITIC: CriticWorker,\n            NodeRole.REWARD: RewardModelWorker,\n        }\n    elif strategy in DAGConstants.MEGATRON_STRATEGYS:\n        from siirl.engine.megatron_workers import (\n            ActorWorker,\n            RolloutWorker,\n            AsyncRolloutWorker,\n            ReferenceWorker,\n            CriticWorker,\n            RewardModelWorker\n        )\n\n        is_async_mode = config.actor_rollout_ref.rollout.mode == 
\"async\"\n\n        return {\n            NodeRole.ACTOR: ActorWorker,\n            NodeRole.ROLLOUT: AsyncRolloutWorker if is_async_mode else RolloutWorker,\n            NodeRole.REFERENCE: ReferenceWorker,\n            NodeRole.CRITIC: CriticWorker,\n            NodeRole.REWARD: RewardModelWorker\n        }\n    raise NotImplementedError(f\"Strategy '{strategy}' is not supported.\")\n\n\ndef get_parallelism_config(reference_node: Node) -> tuple[int, int]:\n    \"\"\"Extracts tensor parallel (TP) and pipeline parallel (PP) sizes from node config.\"\"\"\n    tp_size = 1\n    pp_size = 1\n\n    if intern_config := reference_node.config.get(DAGConstants.INTERN_CONFIG):\n        if reference_node.node_type == NodeType.MODEL_INFERENCE:\n            # Rollout nodes: only TP supported (PP not typically used for inference)\n            tp_size = intern_config.rollout.tensor_model_parallel_size\n            pp_size = 1\n\n        elif reference_node.node_type == NodeType.MODEL_TRAIN:\n            # Extract strategy from config\n            strategy = 'fsdp'  # default\n\n            if hasattr(intern_config, 'actor') and hasattr(intern_config.actor, 'strategy'):\n                strategy = intern_config.actor.strategy\n            elif hasattr(intern_config, 'strategy'):\n                strategy = intern_config.strategy\n\n            if strategy in DAGConstants.MEGATRON_STRATEGYS:\n                # Megatron supports both TP and PP\n                if hasattr(intern_config, 'actor') and hasattr(intern_config.actor, 'megatron'):\n                    tp_size = intern_config.actor.megatron.tensor_model_parallel_size\n                    pp_size = intern_config.actor.megatron.pipeline_model_parallel_size\n                elif hasattr(intern_config, 'megatron'):\n                    tp_size = intern_config.megatron.tensor_model_parallel_size\n                    pp_size = intern_config.megatron.pipeline_model_parallel_size\n            else:\n                # FSDP: no TP/PP, keep TP=PP=1\n                tp_size = 1\n                pp_size = 1\n\n    return tp_size, pp_size\n\n\ndef prepare_generation_batch(batch: TensorDict) -> TensorDict:\n    \"\"\"Pops keys from a batch to isolate data needed for sequence generation.\"\"\"\n    keys_to_pop = [\"input_ids\", \"attention_mask\", \"position_ids\", \"raw_prompt_ids\"]\n    if \"multi_modal_inputs\" in batch:\n        keys_to_pop.extend([\"multi_modal_data\", \"multi_modal_inputs\"])\n    if \"tools_kwargs\" in batch:\n        keys_to_pop.append(\"tools_kwargs\")\n    if \"raw_prompt\" in batch:\n        keys_to_pop.append(\"raw_prompt\")\n    if \"interaction_kwargs\" in batch:\n        keys_to_pop.append(\"interaction_kwargs\")\n    return batch.pop(\n    )\n\n\ndef prepare_local_batch_metrics(batch: TensorDict, use_critic: bool = True) -> Dict[str, torch.Tensor]:\n    \"\"\"Extracts raw metric tensors from batch for distributed aggregation.\"\"\"\n    from siirl.utils.metrics.metric_utils import _compute_response_info\n\n    response_info = _compute_response_info(batch)\n    response_mask = response_info[\"response_mask\"].bool()\n    device = batch[\"advantages\"].device\n    max_response_length = batch[\"responses\"].shape[-1]\n    response_lengths = response_info[\"response_length\"].to(device)\n    prompt_lengths = response_info[\"prompt_length\"].to(device)\n\n    # Components for correct/wrong response length metrics\n    correct_threshold = 0.5\n    rewards_per_response = batch[\"token_level_rewards\"].sum(-1)\n    correct_mask = 
rewards_per_response > correct_threshold\n\n    # Components for prompt clip ratio\n    prompt_attn_mask = batch[\"attention_mask\"][:, :-max_response_length]\n    max_prompt_length = prompt_attn_mask.size(-1)\n\n    # Prepare raw metric values\n    local_data = {\n        \"score\": batch[\"token_level_scores\"].sum(-1),\n        \"rewards\": batch[\"token_level_rewards\"].sum(-1),\n        \"advantages\": torch.masked_select(batch[\"advantages\"], response_mask),\n        \"returns\": torch.masked_select(batch[\"returns\"], response_mask),\n        \"response_length\": response_info[\"response_length\"].to(device),\n        \"prompt_length\": response_info[\"prompt_length\"].to(device),\n        \"correct_response_length\": response_lengths[correct_mask],\n        \"wrong_response_length\": response_lengths[~correct_mask],\n        \"response_clip_ratio\": torch.eq(response_info[\"response_length\"], max_response_length).float(),\n        \"prompt_clip_ratio\": torch.eq(prompt_lengths, max_prompt_length).float(),\n    }\n\n    if use_critic:\n        valid_values = torch.masked_select(batch[\"values\"], response_mask)\n        error = local_data[\"returns\"] - valid_values\n\n        critic_data = {\n            \"values\": valid_values,\n            # Special components for explained variance (summed globally)\n            \"returns_sq_sum_comp\": torch.sum(torch.square(local_data[\"returns\"])),\n            \"error_sum_comp\": torch.sum(error),\n            \"error_sq_sum_comp\": torch.sum(torch.square(error)),\n        }\n        local_data.update(critic_data)\n\n    return local_data\n\n\ndef whether_put_data(rank, is_current_last_pp_tp_rank0, next_dp_size, cur_dp_size, cur_node, next_node) -> bool:\n    # Determine whether to put data into buffer based on node configuration\n    result = False\n    reason = \"No condition met\"\n    \n    if is_current_last_pp_tp_rank0:\n        result = True\n        reason = \"Current last PP rank's TP rank 0\"\n    elif next_dp_size == cur_dp_size:\n        if next_node.node_type in [NodeType.COMPUTE, NodeType.MODEL_TRAIN]:\n            result = True\n            reason = f\"DP sizes match and next node is {next_node.node_type}\"\n    elif cur_node.node_role == next_node.node_role and cur_node.node_role == NodeRole.ROLLOUT:\n        result = True\n        reason = \"Both nodes are ROLLOUT\"\n        \n    logger.debug(f\"Rank {rank}: _whether_put_data decision for {cur_node.node_id}->{next_node.node_id}: {result} ({reason}). 
\"\n                f\"is_current_last_pp_tp_rank0={is_current_last_pp_tp_rank0}, next_dp_size={next_dp_size}, cur_dp_size={cur_dp_size}, \"\n                f\"cur_node_type={cur_node.node_type}, next_node_type={next_node.node_type}, \"\n                f\"cur_node_role={cur_node.node_role}, next_node_role={next_node.node_role}\")\n    return result\n\n\n# ==========================================================================================\n# Section 6: Metrics Collection & Aggregation\n# ==========================================================================================\n\ndef reduce_and_broadcast_metrics(\n    local_metrics: Dict[str, Union[float, List[float], torch.Tensor]],\n    group: dist.ProcessGroup\n) -> Dict[str, float]:\n    \"\"\"Aggregates metrics across all ranks using all_reduce operations.\"\"\"\n    if not isinstance(local_metrics, dict) or not local_metrics:\n        return {}\n\n    world_size = dist.get_world_size(group)\n    if world_size <= 1:\n        # Non-distributed case: perform local aggregation only\n        aggregator = DistributedMetricAggregator(local_metrics, group=None)\n        final_metrics = {}\n        for op_type, data in aggregator.op_buckets.items():\n            for key, value in data:\n                if op_type == _ReduceOp.SUM:  # value is a (sum, count) tuple\n                    final_metrics[key] = value[0] / value[1] if value[1] > 0 else 0.0\n                else:  # value is a float\n                    final_metrics[key] = float(value)\n        return final_metrics\n\n    # Pipeline Parallel: ensure all ranks have same metric keys\n    # 1. Gather all metric keys from all ranks\n    local_keys = set(local_metrics.keys())\n    all_keys_list = [None] * world_size\n    dist.all_gather_object(all_keys_list, local_keys, group=group)\n\n    # 2. Union all keys to get complete set\n    all_expected_keys = set()\n    for keys_set in all_keys_list:\n        all_expected_keys.update(keys_set)\n\n    # 3. 
Aggregate with unified keys\n    aggregator = DistributedMetricAggregator(local_metrics, group)\n    aggregator.op_buckets = aggregator._bucket_local_metrics(local_metrics, all_expected_keys)\n    return aggregator.aggregate_and_get_results()\n\n\ndef format_metrics_by_group(metrics: Dict[str, Any], group_order: List[str]) -> Dict[str, Any]:\n    \"\"\"Reorders metrics by group prefixes and alphabetically within groups.\"\"\"\n    if not metrics:\n        return {}\n\n    ordered_dict = {}\n    processed_keys = set()\n\n    # Pre-identify explicitly mentioned full keys\n    explicitly_mentioned_keys = {key for key in group_order if key in metrics}\n\n    # Process metrics according to group/key order\n    for pattern in group_order:\n        # Check if pattern is a full key\n        if pattern in explicitly_mentioned_keys and pattern not in processed_keys:\n            ordered_dict[pattern] = metrics[pattern]\n            processed_keys.add(pattern)\n        else:\n            # Treat as group prefix\n            group_prefix = f\"{pattern}/\"\n\n            # Find all keys in this group and sort alphabetically\n            keys_in_group = sorted(\n                [\n                    key\n                    for key in metrics\n                    if key.startswith(group_prefix)\n                    and key not in processed_keys\n                    and key not in explicitly_mentioned_keys\n                ]\n            )\n\n            for key in keys_in_group:\n                ordered_dict[key] = metrics[key]\n                processed_keys.add(key)\n\n    # Process remaining keys\n    remaining_keys = sorted([key for key in metrics if key not in processed_keys])\n    if remaining_keys:\n        for key in remaining_keys:\n            ordered_dict[key] = metrics[key]\n\n    return ordered_dict\n\n\n# ==========================================================================================\n# Section 7: Logging & Output\n# ==========================================================================================\n\ndef log_metrics_to_console(rank: int, ordered_metrics: Union[Dict[str, Any], List[Tuple[str, Any]]], step: int):\n    \"\"\"Logs a formatted metrics string to console (rank 0 only). Accepts an ordered dict or a list of (key, value) pairs.\"\"\"\n    if rank != 0:\n        return\n    items = ordered_metrics.items() if isinstance(ordered_metrics, dict) else ordered_metrics\n    log_parts = [f\"step:{step}\"]\n    log_parts.extend([f\"{k}:{v:.4f}\" if isinstance(v, float) else f\"{k}:{v}\" for k, v in items])\n    logger.info(\" | \".join(log_parts))\n\n\ndef dump_validation_generations(\n    config,\n    global_steps: int,\n    rank: int,\n    results: List[ValidationResult]\n):\n    \"\"\"Dumps validation generation results to rank-specific JSON file.\"\"\"\n    dump_path_str = config.trainer.rollout_data_dir\n    if not dump_path_str:\n        return\n    dump_path = Path(dump_path_str)\n\n    try:\n        dump_path.mkdir(parents=True, exist_ok=True)\n\n        filename = dump_path / f\"step_{global_steps}_rank_{rank}.json\"\n\n        # Collect entries\n        entries = []\n        for res in results:\n            entry = {\n                \"rank\": rank,\n                \"global_step\": global_steps,\n                \"data_source\": res.data_source,\n                \"input\": res.input_text,\n                \"output\": res.output_text,\n                \"score\": res.score,\n            }\n            if res.extra_rewards:\n                entry.update(res.extra_rewards)\n            entries.append(entry)\n\n        # Write with pretty formatting\n        with open(filename, \"w\", encoding=\"utf-8\") as f:\n            
json.dump(entries, f, ensure_ascii=False, indent=4)\n\n        if rank == 0:\n            logger.info(f\"Validation generations are being dumped by all ranks to: {dump_path.resolve()}\")\n        logger.debug(f\"Rank {rank}: Dumped {len(results)} validation generations to {filename}\")\n\n    except (OSError, IOError) as e:\n        logger.error(f\"Rank {rank}: Failed to write validation dump file to {dump_path}: {e}\")\n    except Exception as e:\n        logger.error(f\"Rank {rank}: An unexpected error occurred during validation dumping: {e}\", exc_info=True)\n\n\ndef aggregate_and_write_performance_metrics(\n    gather_group,\n    rank,\n    global_steps,\n    config,\n    metrics: Dict[str, Any]):\n    \"\"\"\n    Gathers performance metrics from all ranks to rank 0 and writes them to a CSV file.\n    Each row corresponds to a metric key COMMON to all ranks, and each column to a rank.\n    This function is called only if performance profiling is enabled.\n    \"\"\"\n    # Gather all metrics dictionaries to rank 0\n    world_size = dist.get_world_size()\n    gathered_metrics = [None] * world_size if rank == 0 else None\n    dist.gather_object(metrics, gathered_metrics, dst=0, group=gather_group)\n\n    if rank == 0:\n        if not gathered_metrics:\n            logger.warning(\"No metrics gathered on rank 0. Skipping performance CSV write.\")\n            return\n\n        valid_metrics = [m for m in gathered_metrics if isinstance(m, dict) and m]\n        if not valid_metrics:\n            logger.warning(\"No valid metric dictionaries received on rank 0. Skipping CSV write.\")\n            return\n\n        common_keys = set(valid_metrics[0].keys())\n        for rank_metrics in valid_metrics[1:]:\n            common_keys.intersection_update(rank_metrics.keys())\n\n        sorted_keys = sorted(list(common_keys))\n\n        if not sorted_keys:\n            logger.warning(\n                f\"No common metric keys found across all ranks for step {global_steps}. 
Skipping CSV write.\"\n            )\n            return\n\n        ts = get_time_now().strftime(\"%Y-%m-%d-%H-%M-%S\")\n        try:\n            # Try to get model name from model path config\n            model_name = os.path.basename(os.path.normpath(config.actor_rollout_ref.model.path))\n            output_dir = os.path.join(\"performance_logs\", model_name, ts)\n            os.makedirs(output_dir, exist_ok=True)\n        except OSError as e:\n            logger.error(f\"Failed to create performance log directory {output_dir}: {e}\")\n            return\n\n        filename = os.path.join(output_dir, f\"world_{world_size}_step_{global_steps}_common_metrics.csv\")\n\n        try:\n            with open(filename, \"w\", newline=\"\", encoding=\"utf-8\") as csvfile:\n                writer = csv.writer(csvfile)\n\n                header = (\n                    [\"metric\"]\n                    + [f\"rank_{i}\" for i in range(world_size)]\n                    + [\"max\", \"min\", \"delta_max_min\", \"delta_max_rank_0\"]\n                )\n                writer.writerow(header)\n\n                for key in sorted_keys:\n                    row = [key]\n                    for i in range(world_size):\n                        rank_metrics = gathered_metrics[i]\n\n                        if isinstance(rank_metrics, dict):\n                            value = rank_metrics.get(key, \"Error: Key Missing\")\n                        else:\n                            value = \"N/A: Invalid Data\"\n                        row.append(value)\n\n                    row_max = max([x for x in row[1:] if isinstance(x, (int, float))], default=\"N/A\")\n                    row_min = min([x for x in row[1:] if isinstance(x, (int, float))], default=\"N/A\")\n                    row_delta_max = (\n                        row_max - row_min\n                        if isinstance(row_max, (int, float)) and isinstance(row_min, (int, float))\n                        else \"N/A\"\n                    )\n                    row_delta_rank0 = row_max - row[1] if isinstance(row[1], (int, float)) else \"N/A\"\n                    row.extend([row_max, row_min, row_delta_max, row_delta_rank0])\n                    writer.writerow(row)\n\n            logger.info(\n                f\"Common performance metrics for step {global_steps} successfully written to {filename}\"\n            )\n\n        except OSError as e:\n            logger.error(f\"Failed to write performance metrics to CSV file {filename}: {e}\")\n\n\ndef log_core_performance_metrics(rank: int, enable_perf: bool, metrics: Dict[str, Any], step: int):\n    \"\"\"\n    Logs a formatted, easy-to-read summary of core performance metrics on rank 0.\n    This provides a clear, separate view of the most important indicators.\n    \"\"\"\n    if rank != 0:\n        return\n\n    def get_metric(key, precision=3):\n        val = metrics.get(key)\n        if val is None:\n            return \"N/A\"\n        if isinstance(val, (float, np.floating)):\n            return f\"{val:.{precision}f}\"\n        return val\n\n    # --- Build the log string ---\n    log_str = f\"\\n\\n{'=' * 25} RANK({rank}): Core Performance Metrics (Step: {step}) {'=' * 25}\\n\"\n\n    # --- Overall Performance ---\n    log_str += \"\\n--- ⏱️  Overall Performance ---\\n\"\n    log_str += f\"  {'Step Time':<28}: {get_metric('perf/time_per_step', 3)} s\\n\"\n    log_str += f\"  {'Throughput (tokens/s)':<28}: {get_metric('perf/throughput', 2)}\\n\"\n    log_str += f\"  {'Total Tokens in Step':<28}: 
{get_metric('perf/total_num_tokens', 0)}\\n\"\n\n    # --- Algorithm-Specific Metrics ---\n    log_str += \"\\n--- 📈 Algorithm Metrics ---\\n\"\n    log_str += f\"  {'Actor Entropy':<28}: {get_metric('actor/entropy_loss', 4)}\\n\"\n    log_str += (\n        f\"  {'Critic Rewards (Mean/Min/Max)':<28}: {get_metric('critic/rewards/mean', 3)} / \"\n        f\"{get_metric('critic/rewards/min', 3)} / {get_metric('critic/rewards/max', 3)}\\n\"\n    )\n    log_str += (\n        f\"  {'Critic Scores (Mean/Min/Max)':<28}: {get_metric('critic/score/mean', 3)} / \"\n        f\"{get_metric('critic/score/min', 3)} / {get_metric('critic/score/max', 3)}\\n\"\n    )\n\n    if enable_perf:\n        # --- Module-wise Timings (Single Column) ---\n        log_str += \"\\n--- ⏳ Module-wise Timings (s) ---\\n\"\n        # Dynamically find all delta_time metrics except the total step time\n        timing_keys = sorted(\n            [k for k in metrics.keys() if k.startswith(\"perf/delta_time/\") and k != \"perf/delta_time/step\"]\n        )\n\n        ref_key = \"perf/delta_time/ref\"\n        reference_key = \"perf/delta_time/reference\"\n        if ref_key in timing_keys and reference_key in timing_keys:\n            timing_keys.remove(reference_key)\n\n        if timing_keys:\n            # Find the maximum label length across all keys for clean alignment\n            max_label_len = 0\n            if timing_keys:\n                max_label_len = max(\n                    len(k.replace(\"perf/delta_time/\", \"\").replace(\"_\", \" \").title()) for k in timing_keys\n                )\n\n            for key in timing_keys:\n                label = key.replace(\"perf/delta_time/\", \"\").replace(\"_\", \" \").title()\n                value = get_metric(key, 3)\n                log_str += f\"  {label:<{max_label_len}} : {value}s\\n\"\n        else:\n            log_str += \"  No detailed timing metrics available.\\n\"\n\n    # --- Model Flops Utilization (MFU) ---\n    log_str += \"\\n--- 🔥 Model Flops Utilization (MFU) ---\\n\"\n    log_str += f\"  {'Mean MFU':<28}: {get_metric('perf/mfu/mean', 3)}\\n\"\n    log_str += f\"  {'Actor Training MFU':<28}: {get_metric('perf/mfu/actor', 3)}\\n\"\n    # log_str += f\"  {'Rollout MFU':<28}: {get_metric('perf/mfu/rollout', 3)}\\n\"\n    log_str += f\"  {'Reference Policy MFU':<28}: {get_metric('perf/mfu/ref', 3)}\\n\"\n    log_str += f\"  {'Actor LogProb MFU':<28}: {get_metric('perf/mfu/actor_log_prob', 3)}\\n\"\n\n    # --- Memory Usage ---\n    log_str += \"\\n--- 💾 Memory Usage ---\\n\"\n    log_str += f\"  {'Max GPU Memory Allocated':<28}: {get_metric('perf/max_memory_allocated_gb', 2)} GB\\n\"\n    log_str += f\"  {'Max GPU Memory Reserved':<28}: {get_metric('perf/max_memory_reserved_gb', 2)} GB\\n\"\n    log_str += f\"  {'CPU Memory Used':<28}: {get_metric('perf/cpu_memory_used_gb', 2)} GB\\n\"\n\n    # --- Sequence Lengths ---\n    log_str += \"\\n--- 📏 Sequence Lengths ---\\n\"\n    log_str += (\n        f\"  {'Prompt Length (Mean/Max)':<28}: {get_metric('prompt/length/mean', 1)} / \"\n        f\"{get_metric('prompt/length/max', 0)}\\n\"\n    )\n    log_str += (\n        f\"  {'Response Length (Mean/Max)':<28}: {get_metric('response/length/mean', 1)} / \"\n        f\"{get_metric('response/length/max', 0)}\\n\"\n    )\n    log_str += f\"  {'Response Clip Ratio':<28}: {get_metric('response/clip_ratio/mean', 4)}\\n\"\n    log_str += f\"  {'Prompt Clip Ratio':<28}: {get_metric('prompt/clip_ratio/mean', 4)}\\n\"\n    log_str += (\n        f\"  {'Correct Resp Len 
(Mean/Max)':<28}: {get_metric('response/correct_length/mean', 1)} / \"\n        f\"{get_metric('response/correct_length/max', 0)}\\n\"\n    )\n    log_str += (\n        f\"  {'Wrong Resp Len (Mean/Max)':<28}: {get_metric('response/wrong_length/mean', 1)} / \"\n        f\"{get_metric('response/wrong_length/max', 0)}\\n\"\n    )\n\n    log_str += \"\\n\" + \"=\" * 82 + \"\\n\"\n    logger.info(log_str)\n\n\n# ==========================================================================================\n# Section 8: General Utilities\n# ==========================================================================================\n\ndef get_time_now(time_zone: str = \"Asia/Shanghai\") -> datetime:\n    \"\"\"Returns the current time in the specified timezone.\"\"\"\n    return datetime.now(tz=ZoneInfo(time_zone))\n\n\ndef consistent_hash(s: str) -> int:\n    \"\"\"Returns a consistent hash of a string using MD5.\"\"\"\n    return int(hashlib.md5(s.encode()).hexdigest(), 16)\n"
  },
  {
    "path": "siirl/dag_worker/dagworker.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport os\nimport uuid\nimport ray\nimport torch\nimport asyncio\nimport numpy as np\nimport torch.distributed as dist\nfrom collections import defaultdict\nfrom pprint import pformat\nfrom tqdm import tqdm\nfrom loguru import logger\nfrom typing import Any, Dict, List, Optional, Set, Tuple, Type, Callable\nfrom torch.distributed import ProcessGroup\nfrom tensordict import TensorDict\n# Handle different tensordict versions - NonTensorData location varies\ntry:\n    from tensordict import NonTensorData\nexcept ImportError:\n    from tensordict.tensorclass import NonTensorData\nimport time\nfrom siirl.execution.metric_worker.metric_worker import MetricClient\nfrom siirl.models.loader import TokenizerModule, load_tokenizer\nfrom siirl.params import SiiRLArguments\nfrom siirl.engine.base_worker import Worker\nfrom siirl.execution.dag import TaskGraph\nfrom siirl.execution.dag.node import NodeRole, NodeType, Node\nfrom siirl.execution.scheduler.reward import compute_reward, create_reward_manager\nfrom siirl.execution.scheduler.process_group_manager import ProcessGroupManager\nfrom siirl.execution.scheduler.enums import AdvantageEstimator, WorkflowType\nfrom siirl.data_coordinator import preprocess_dataloader, Samples2Dict, Dict2Samples, SampleInfo\nfrom siirl.data_coordinator.dataloader import DataLoaderNode\nfrom siirl.dag_worker.data_structures import NodeOutput\nfrom siirl.dag_worker.constants import DAGConstants, DAGInitializationError\nfrom siirl.dag_worker import core_algos\nfrom siirl.dag_worker.checkpoint_manager import CheckpointManager\nfrom siirl.dag_worker.core_algos import (\n    agg_loss,\n    apply_kl_penalty,\n    compute_advantage,\n    compute_response_mask\n    )\nfrom siirl.dag_worker.dag_utils import  (\n    log_ray_actor_info,\n    get_and_validate_rank,\n    get_taskgraph_for_rank,\n    log_role_worker_mapping,\n    should_create_worker,\n    generate_node_worker_key,\n    find_first_non_compute_ancestor,\n    setup_sharding_manager,\n    get_worker_classes,\n    get_parallelism_config,\n    prepare_generation_batch,\n    format_metrics_by_group,\n    log_metrics_to_console,\n    aggregate_and_write_performance_metrics,\n    log_core_performance_metrics,\n    timer,\n    reduce_and_broadcast_metrics,\n    whether_put_data\n    )\nfrom siirl.utils.debug import DistProfiler\nfrom siirl.utils.extras.device import get_device_name, get_nccl_backend\nfrom siirl.execution.rollout_flow.multiturn.agent_loop import AgentLoopManager\n\ndevice_name = get_device_name()\n\nclass DAGWorker(Worker):\n    \"\"\"\n    Orchestrates a Directed Acyclic Graph (DAG) of tasks for distributed training,\n    managing the setup, initialization, and workflow for a specific rank.\n    \"\"\"\n\n    def __init__(\n        self,\n        config: SiiRLArguments,\n        process_group_manager: ProcessGroupManager,\n        taskgraph_mapping: Dict[int, TaskGraph],\n   
     data_coordinator: \"ray.actor.ActorHandle\",\n        metric_worker: \"ray.actor.ActorHandle\",\n        device_name=\"cuda\",\n    ):\n        super().__init__()\n        self.config = config\n        self.process_group_manager = process_group_manager\n        self.taskgraph_mapping = taskgraph_mapping\n        self.data_coordinator = data_coordinator\n        self.device_name = device_name\n        self.enable_perf = os.environ.get(\"SIIRL_ENABLE_PERF\", \"0\") == \"1\" or config.dag.enable_perf\n\n        # State attributes\n        self.timing_raw = {}\n        self.global_steps = 0\n        self.total_training_steps = 0\n        self.workers: Dict[str, Any] = {}\n        self.multi_agent_group: Dict[int, Dict[NodeRole, Any]] = defaultdict(dict)\n        self.agent_group_process_group: Dict[int, Dict[NodeRole, Any]] = defaultdict(dict)\n        self.process_groups: Dict[str, ProcessGroup] = {}\n        self.tokenizer_mapping: Dict[str, TokenizerModule] = {}\n        self.logger = None\n        self.progress_bar = None\n        self._rank: int = -1\n        self.taskgraph: Optional[TaskGraph] = None\n        self.internal_data_cache: Dict[str, Any] = {}\n        self.sample_ref_cache: list = []\n        self.agent_critic_worker: Any\n        # Finish flag\n        self.taskgraph_execute_finished = False\n\n        # async rollout\n        self.rollout_mode = \"sync\"\n        self._async_rollout_manager = None\n        self.zmq_address = None  # used for async_vllmrollout\n\n        # Add a cache to hold data from an insufficient batch for the next training step.\n        # This is the core state-carrying mechanism for dynamic sampling.\n        self.sampling_leftover_cache: Optional[Any] = None\n\n        # multi agent\n        self._multi_agent = False\n        \n        # metirc_worker\n        self.metric_worker = MetricClient(metric_worker=metric_worker)\n        try:\n            self._initialize_worker()\n        except (ValueError, TypeError, KeyError, AttributeError, NotImplementedError) as e:\n            rank = os.environ.get(\"RANK\", \"UNKNOWN\")\n            logger.error(f\"Rank {rank}: Failed to create DAGWorker due to a critical setup error: {e}\", exc_info=True)\n            raise DAGInitializationError(f\"Initialization failed on Rank {rank}: {e}\") from e\n\n        log_ray_actor_info(self._rank)\n\n# ==========================================================================================\n# Module 1: Execution and Training Loop\n# ==========================================================================================\n\n    def execute_task_graph(self):\n        \"\"\"Main entry point to start the DAG execution pipeline.\"\"\"\n        logger.info(f\"Rank {self._rank}: Starting DAG execution pipeline...\")\n        logger.success(f\"Rank {self._rank}: All components initialized. Starting training loop from step {self.global_steps + 1}.\")\n\n        if self.config.trainer.val_before_train:\n            self.validator.validate(global_step=self.global_steps)\n            self.metric_worker.wait_submit()\n            dist.barrier(self._gather_group)\n            if self._rank == 0 and self.logger:\n                val_metrics = self.metric_worker.wait_final_res() \n                logger.info(f\"Initial validation metrics:\\n{pformat(val_metrics)}\")\n                self.logger.log(data=val_metrics, step=self.global_steps)\n\n            if self.config.trainer.val_only:\n                logger.info(\"`val_only` is true. 
Halting after initial validation.\")\n                return\n        self._run_training_loop()\n\n        if self.progress_bar:\n            self.progress_bar.close()\n        self.taskgraph_execute_finished = True\n        logger.success(f\"Rank {self._rank}: DAG execution finished.\")\n\n    def _run_training_loop(self):\n        \"\"\"\n        The main loop that iterates through training steps and epochs.\n        \"\"\"\n        self.total_training_steps = self.dataloader.total_training_steps\n        if self.dataloader.num_train_batches <= 0:\n            if self._rank == 0:\n                logger.warning(f\"num_train_batches is {self.dataloader.num_train_batches}. The training loop will be skipped.\")\n            return\n\n        if self._rank == 0:\n            self.progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc=\"Training Progress\")\n\n        last_val_metrics = None\n\n        # Calculate starting epoch and batches to skip in that epoch for resumption.\n        start_epoch = 0\n        batches_to_skip = 0\n        if self.dataloader.num_train_batches > 0:\n            start_epoch = self.global_steps // self.dataloader.num_train_batches\n            batches_to_skip = self.global_steps % self.dataloader.num_train_batches\n\n        for epoch in range(start_epoch, self.config.trainer.total_epochs):\n            is_embodied = self.config.algorithm.workflow_type == WorkflowType.EMBODIED\n            if is_embodied:\n                self._cleanup_step_buffers(self.timing_raw)\n            for batch_idx in range(self.dataloader.num_train_batches):\n                if epoch == start_epoch and batch_idx < batches_to_skip:\n                    continue\n\n                if self.global_steps >= self.total_training_steps:\n                    logger.info(f\"Rank {self._rank}: Reached total training steps. 
Exiting loop.\")\n                    if self._rank == 0 and last_val_metrics:\n                        logger.info(f\"Final validation metrics:\\n{pformat(last_val_metrics)}\")\n                    return\n                \n                if self.global_steps in self.config.profiler.profile_steps:\n                    self._profiler.start(role=\"e2e\", profile_step=self.global_steps)\n                    \n                ordered_metrics = self._run_training_step(epoch, batch_idx)\n                \n                if self.global_steps in self.config.profiler.profile_steps:\n                    self._profiler.stop()\n\n                if ordered_metrics is None:\n                    if self.progress_bar:\n                        self.progress_bar.update(1)\n                    continue\n\n                self.global_steps += 1\n\n                is_last_step = self.global_steps >= self.total_training_steps\n\n                if self.config.trainer.save_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.save_freq == 0):\n                    self.checkpoint_manager.save_checkpoint(self.global_steps)\n\n                if self.config.trainer.test_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0):\n                    self.validator.validate(global_step=self.global_steps)\n                    self.metric_worker.wait_submit()\n                    dist.barrier(self._gather_group)\n                    if self._rank == 0:\n                        val_metric = self.metric_worker.wait_final_res()\n                        ordered_metrics.update(val_metric)\n                        if is_last_step:\n                            last_val_metrics = val_metric\n\n                if self.enable_perf:\n                    aggregate_and_write_performance_metrics(self._gather_group, self._rank, self.global_steps, self.config, ordered_metrics)\n                ordered_metric_dict = format_metrics_by_group(ordered_metrics, DAGConstants.METRIC_GROUP_ORDER)\n                log_core_performance_metrics(self._rank, self.enable_perf, ordered_metric_dict, self.global_steps)\n                if self._rank == 0:\n                    if self.logger:\n                        self.logger.log(data=ordered_metric_dict, step=self.global_steps)\n                    else:\n                        log_metrics_to_console(self._rank, ordered_metric_dict, self.global_steps)\n\n                if self.progress_bar and not (epoch == start_epoch and batch_idx < batches_to_skip):\n                    self.progress_bar.update(1)\n\n        if self._rank == 0 and last_val_metrics:\n            logger.info(f\"Final validation metrics:\\n{pformat(last_val_metrics)}\")\n\n    def _cleanup_step_buffers(self, timing_raw: dict) -> None:\n        \"\"\"\n        Encapsulates the logic for resetting and clearing all step-related buffers.\n        This includes the distributed Ray data buffers and the local internal cache.\n        This is called at the end of a step, whether it completed successfully or was aborted.\n        \"\"\"\n        # Reset the distributed (Ray) buffers for all keys that were used in this step.\n        with timer(self.enable_perf, \"reset_data_buffer\", timing_raw):\n            self.reset_data_buffer()\n            for ref in self.sample_ref_cache:\n                ray.internal.free(ref)\n            self.sample_ref_cache = []\n        # Clear the local, in-process cache for the next step.\n        with timer(self.enable_perf, \"reset_intern_data_buffer\", 
timing_raw):\n            self.internal_data_cache.clear()\n\n    def _run_training_step(self, epoch: int, batch_idx: int) -> Optional[List[Tuple[str, Any]]]:\n        \"\"\"Executes a single training step by traversing the computational graph.\"\"\"\n        timing_raw, ordered_metrics = self.timing_raw, []\n\n        with timer(self.enable_perf, \"step\", timing_raw):\n            # --- 1. Data Loading ---\n            with timer(self.enable_perf, \"get_data_from_dataloader\", timing_raw):\n                is_embodied = self.config.actor_rollout_ref.model.model_type == \"embodied\"\n                repeat_n = self.config.actor_rollout_ref.rollout.n\n                batch = preprocess_dataloader(\n                    self.dataloader.run(epoch=epoch, is_validation_step=False),\n                    repeat_n\n                )\n            node_queue = self.taskgraph.get_entry_nodes()\n            if not node_queue:\n                logger.error(\"Taskgraph has no entry nodes. Cannot start execution.\")\n                return None\n            entry_node_id = node_queue[0].node_id\n\n            # --- 2. Graph Traversal ---\n            visited_nodes = set()\n            with timer(self.enable_perf, \"graph_execution\", timing_raw):\n                while node_queue:\n                    cur_node = node_queue.pop(0)\n                    if cur_node.node_id in visited_nodes:\n                        continue\n                    visited_nodes.add(cur_node.node_id)\n\n                    cur_dp_size, cur_dp_rank, cur_tp_rank, cur_tp_size, cur_pp_rank, cur_pp_size = self._get_node_dp_info(cur_node)\n                    logger.debug(f\"current node({cur_node.node_id}) dp_size: {cur_dp_size}, dp_rank: {cur_dp_rank}, tp_rank: {cur_tp_rank}, pp_rank: {cur_pp_rank}, pp_size: {cur_pp_size}\")\n\n                    # --- 3. Get Input Data ---\n                    if cur_node.node_id != entry_node_id:\n                        with timer(self.enable_perf, \"get_data_from_buffer\", timing_raw):\n                            batch = self.get_data_from_buffers(key=cur_node.node_id, cur_dp_size=cur_dp_size, cur_dp_rank=cur_dp_rank, timing_raw=timing_raw)\n                            if batch is None:\n                                embodied_sampling = self.config.algorithm.embodied_sampling\n                                allow_insufficient = (\n                                    self.config.algorithm.filter_groups.enable\n                                    or embodied_sampling.filter_accuracy\n                                    or embodied_sampling.filter_truncated\n                                )\n                                if allow_insufficient:\n                                    # Dynamic sampling scenario - waiting for data is expected behavior\n                                    if cur_node.node_role == NodeRole.ACTOR:\n                                        logger.debug(f\"Rank {self._rank}: Waiting for sufficient data for node {cur_node.node_id}. Skipping this step.\")\n                                        return None \n                                else:\n                                    logger.error(f\"Rank {self._rank}: Failed to get data for node {cur_node.node_id}. 
Skipping step.\")\n                                    return None \n                            else:\n                                # batch = remove_prefix_from_dataproto(batch, cur_node)\n                                logger.debug(f\"current node({cur_node.node_id}) get data from databuffer batch size: {batch.size()}\")\n                    if self.enable_perf:\n                        with timer(self.enable_perf, \"get_data_from_buffer_barrier\", timing_raw):\n                            dist.barrier(self._gather_group)\n                    # --- 4. Node Execution ---\n                    node_name_timer = f\"{cur_node.node_id}\"\n                    with timer(self.enable_perf, node_name_timer, timing_raw):\n                        if cur_node.executable and batch is not None:\n                            node_kwargs = {\"_dag_worker_instance\": self}\n                            node_kwargs[\"process_group\"] = self._get_node_process_group(cur_node) if cur_node.node_type != NodeType.COMPUTE else None\n                            node_kwargs[\"agent_group\"] = self.multi_agent_group[cur_node.agent_group]\n                            node_kwargs[\"cur_tp_rank\"] = cur_tp_rank\n                            if cur_node.node_role == NodeRole.REWARD:\n                                node_kwargs[\"tp_size\"] = cur_tp_size\n                                # Add parallelism info to batch for distributed reward computation\n                                batch[\"dp_size\"] = NonTensorData(cur_dp_size)\n                                batch[\"dp_rank\"] = NonTensorData(cur_dp_rank)\n                                batch[\"tp_rank\"] = NonTensorData(cur_tp_rank)\n                                batch[\"tp_size\"] = NonTensorData(cur_tp_size)\n                                batch[\"pp_rank\"] = NonTensorData(cur_pp_rank)\n                                batch[\"pp_size\"] = NonTensorData(cur_pp_size)\n                            elif cur_node.node_role == NodeRole.ADVANTAGE:\n                                node_kwargs[\"cur_node\"] = cur_node\n\n                            if cur_node.agent_options and cur_node.agent_options.train_cycle:\n                                cycle_round = self.global_steps // cur_node.agent_options.train_cycle\n                                agent_num = len(self.multi_agent_group)\n                                if cycle_round % agent_num == cur_node.agent_group:\n                                    node_output = cur_node.run(batch=batch,\n                                                               config=self.config,\n                                                               **node_kwargs)\n                                else:\n                                    node_output = NodeOutput(batch=batch)\n                            else:\n                                node_output = cur_node.run(batch=batch,\n                                                           config=self.config,\n                                                           **node_kwargs)\n                        else:\n                            logger.warning(f\"Node {cur_node.node_id} has no executable. 
Passing data through.\")\n                            node_output = NodeOutput(batch=batch)\n                    \n                    # Check if node returned empty batch (e.g., DAPO insufficient samples)\n                    # This triggers re-rollout to collect more data\n                    if node_output.batch is None or (node_output.batch is not None and len(node_output.batch) == 0):\n                        logger.warning(\n                            f\"Rank {self._rank}: Node '{cur_node.node_id}' returned empty batch. \"\n                        )\n                        embodied_sampling = self.config.algorithm.embodied_sampling\n                        allow_insufficient = (\n                            self.config.algorithm.filter_groups.enable\n                            or embodied_sampling.filter_accuracy\n                            or embodied_sampling.filter_truncated\n                        )\n                        if not allow_insufficient:\n                            logger.warning(\n                                f\"Rank {self._rank}: Node '{cur_node.node_id}' returned empty batch. \"\n                                f\"Aborting current step to trigger re-rollout. {node_output.batch is not None and len(node_output.batch) != 0}\"\n                            )\n                            return None\n                    \n                    if self.enable_perf:        \n                        with timer(self.enable_perf, f\"{node_name_timer}_barrier\", timing_raw):\n                            dist.barrier(self._gather_group)\n                    if cur_node.node_role == NodeRole.ROLLOUT and self._multi_agent:\n                        next_nodes = self.taskgraph.get_downstream_nodes(cur_node.node_id)\n                        while next_nodes[0].node_role == NodeRole.ROLLOUT:\n                            cur_node = next_nodes[0]\n                            next_nodes = self.taskgraph.get_downstream_nodes(cur_node.node_id)\n\n                    # --- 5. 
Process Output & Get next node ---\n                    with timer(self.enable_perf, \"graph_output_handling\", timing_raw):\n                        if node_output.metrics is not None and len(node_output.metrics) > 0 and cur_tp_rank == 0 and cur_pp_rank == 0:\n                            self.metric_worker.submit_metric(node_output.metrics, cur_dp_size)\n                        if next_nodes := self.taskgraph.get_downstream_nodes(cur_node.node_id):\n                            if node_output.batch is not None and len(node_output.batch) != 0:\n                                # Currently supports single downstream node, can be extended to a loop.\n                                next_node = next_nodes[0]\n                                next_dp_size, _, _, _, _, _ = self._get_node_dp_info(next_node)\n                                # node_output.batch = add_prefix_to_dataproto(node_output.batch, cur_node)\n                                is_current_last_pp_tp_rank0 = (cur_pp_rank == cur_pp_size - 1 and cur_tp_rank == 0)\n                                if whether_put_data(self._rank, is_current_last_pp_tp_rank0, next_dp_size, cur_dp_size, cur_node, next_node):\n                                    with timer(self.enable_perf, \"put_data_to_buffer\", timing_raw):\n                                        # Determine if we need to force data through DataCoordinator\n                                        # This is needed when filter causes data imbalance and requires rebalancing\n                                        embodied_sampling = self.config.algorithm.embodied_sampling\n                                        \n                                        # Check if any filtering is enabled (causes data imbalance)\n                                        has_filtering = (\n                                            self.config.algorithm.filter_groups.enable\n                                            or embodied_sampling.filter_accuracy\n                                            or embodied_sampling.filter_truncated\n                                        )\n                                        \n                                        # Check if current node is embodied filter node\n                                        is_embodied_filter_node = (cur_node.node_id == \"embodied_sampling\")\n                                        \n                                        # Check if this is a COMPUTE -> consumer transition that needs rebalancing\n                                        is_compute_output = (cur_node.node_type == NodeType.COMPUTE)\n                                        needs_rebalance = (\n                                            next_node.node_type == NodeType.MODEL_TRAIN\n                                            or (is_embodied_filter_node and next_node.node_role == NodeRole.REWARD)\n                                        )\n                                        \n                                        enforce_buffer = has_filtering and is_compute_output and needs_rebalance\n                                        \n                                        self.put_data_to_buffers(key=next_node.node_id, data=node_output.batch, source_dp_size=cur_dp_size, dest_dp_size=next_dp_size, enforce_buffer=enforce_buffer, timing_raw=timing_raw)\n                        # elif self._multi_agent:\n                        #     # last_node add prefix for metrics\n                        #     node_output.batch = add_prefix_to_dataproto(node_output.batch, cur_node)    
                    \n                        if self.enable_perf:\n                            with timer(self.enable_perf, \"put_data_to_buffer_barrier\", timing_raw):\n                                dist.barrier(self._gather_group)\n                        with timer(self.enable_perf, \"get_next_node\", timing_raw):\n                            for n in next_nodes:\n                                if n.node_id not in visited_nodes:\n                                    node_queue.append(n)\n\n                    with timer(self.enable_perf, \"step_barrier\", timing_raw):\n                        dist.barrier(self._gather_group)\n\n            # --- 6. Final Metrics Collection ---\n            self._cleanup_step_buffers(timing_raw)\n\n        ordered_metrics = {}\n        if cur_tp_rank == 0 and cur_pp_rank == 0:\n            self.metric_worker.compute_local_data_metric(batch, cur_dp_size)\n            self.metric_worker.compute_local_throughout_metrics(batch, timing_raw, cur_pp_size * cur_tp_size , cur_dp_size)\n            if self._rank == 0:\n                # only use rank0 time metrics\n                self.metric_worker.compute_local_timing_metrics(batch, timing_raw, 1)  \n        timing_raw.clear()\n        self.metric_worker.wait_submit()\n        dist.barrier(self._gather_group)\n        if self._rank == 0:\n            metrics = self.metric_worker.wait_final_res()\n            ordered_metrics = dict(sorted(metrics.items()))\n            ordered_metrics.update({\"training/global_step\": self.global_steps + 1, \"training/epoch\": epoch + 1})\n\n        return ordered_metrics\n\n# ==========================================================================================\n# Module 2: Graph Node Execution Handlers\n# ==========================================================================================\n\n    @DistProfiler.annotate(role=\"generate\")\n    def generate_sync_mode(self, agent_group, batch: TensorDict) -> NodeOutput:\n        \"\"\"Sync mode\"\"\"\n        gen_output = agent_group[NodeRole.ROLLOUT].generate_sequences(batch)\n        if \"response_mask\" not in batch:\n            gen_output[\"response_mask\"] = compute_response_mask(gen_output)\n        batch = batch.update(gen_output)\n        return NodeOutput(batch=batch, metrics=gen_output[\"metrics\"])\n\n    @DistProfiler.annotate(role=\"generate\")\n    def generate_async_mode(self, batch: TensorDict) -> NodeOutput:\n        \"\"\"Async mode\"\"\"\n        if self._async_rollout_manager is not None:\n            loop = asyncio.get_event_loop()\n            gen_output = loop.run_until_complete(self._async_rollout_manager.generate_sequences(batch))\n            metrics = gen_output[\"metrics\"]\n            if \"response_mask\" not in batch:\n                batch[\"response_mask\"] = compute_response_mask(batch)\n            return NodeOutput(batch=batch, metrics=metrics)\n        return NodeOutput(batch=batch, metrics={})\n\n    @DistProfiler.annotate(role=\"generate\")\n    def generate_multi_agent_mode(self, config, batch: TensorDict) -> NodeOutput:\n        \"\"\"Generates sequences for a training batch using the multi-agent rollout model.\"\"\"\n        gen_batch = prepare_generation_batch(batch)\n        if config.actor_rollout_ref.rollout.agent.rewards_with_env and \"reward_model\" in batch.non_tensor_batch:\n            gen_batch.non_tensor_batch[\"reward_model\"] = batch.non_tensor_batch[\"reward_model\"]\n        assert config.actor_rollout_ref.rollout.name == 'sglang'\n        gen_output = 
self.multi_agent_loop.generate_sequence(gen_batch)\n        if gen_output:\n            metrics = gen_output.meta_info.get(\"metrics\", {})\n            # gen_output.meta_info = {}\n            # batch.non_tensor_batch[\"uid\"] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))])\n            # batch = batch.repeat(config.actor_rollout_ref.rollout.n, interleave=True).union(gen_output)\n            # if \"response_mask\" not in batch.batch:\n            #     batch.batch[\"response_mask\"] = compute_response_mask(batch)\n            return NodeOutput(batch=gen_output, metrics=metrics)\n        return NodeOutput(batch=batch, metrics={})\n\n    @DistProfiler.annotate(role=\"generate\")\n    def generate_embodied_mode(self, agent_group, batch: TensorDict, **kwargs) -> NodeOutput:\n        \"\"\"\n        Generates embodied episodes for training.\n        \n        This method follows the same pattern as _generate_for_embodied_validation in validation_mixin,\n        but configured for training mode (do_sample=True, validate=False).\n        \n        For embodied tasks, the batch contains task metadata (task_id, trial_id, etc.) from the dataloader.\n        The rollout worker interacts with the environment and generates all required data\n        (input_ids, pixel_values, responses, etc.) during environment rollout.\n        \n        Unlike text generation, we do NOT call _prepare_generation_batch because:\n        1. The input batch doesn't have text-generation keys (input_ids, attention_mask, etc.)\n        2. These keys will be generated by the embodied rollout worker during env interaction\n        \"\"\"\n        from loguru import logger\n        \n        rollout_worker = agent_group[NodeRole.ROLLOUT]\n        rollout_n = self.config.actor_rollout_ref.rollout.n\n        \n        # Set meta_info for embodied training\n        batch[\"eos_token_id\"] = NonTensorData(self.validate_tokenizer.eos_token_id if self.validate_tokenizer else None)\n        batch[\"n_samples\"] = NonTensorData(self.config.actor_rollout_ref.rollout.n)\n        batch[\"pad_token_id\"] = NonTensorData(self.validate_tokenizer.pad_token_id if self.validate_tokenizer else None)        \n        logger.info(\n            f\"[Embodied Training] Batch variables: \"\n            f\"batch_size={batch.batch_size[0]}, \"\n            f\"eos_token_id={batch['eos_token_id']}, \"\n            f\"pad_token_id={batch['pad_token_id']}, \"\n            f\"n_samples={batch['n_samples']} (dataloader already repeated {rollout_n}x)\"\n        )\n        # Generate embodied episodes\n        gen_output = rollout_worker.generate_sequences(batch)\n        # Extract metrics (may be wrapped in NonTensorData)\n        raw_metrics = gen_output.get(\"metrics\", {}) if hasattr(gen_output, \"get\") else {}\n        metrics = raw_metrics.data if hasattr(raw_metrics, 'data') else (raw_metrics if isinstance(raw_metrics, dict) else {})\n\n        # Merge generated data into batch\n        batch.update(gen_output)\n        \n        # Compute response mask if not already present\n        if \"response_mask\" not in batch:\n            batch[\"response_mask\"] = compute_response_mask(batch)\n\n        return NodeOutput(batch=batch, metrics=metrics)\n    \n    \n    \n    @DistProfiler.annotate(role=\"generate\")\n    def generate(self, config, batch: TensorDict, **kwargs) -> NodeOutput:\n        \"\"\"Generates sequences for a training batch using the rollout model.\"\"\"\n        # Check if this is embodied mode\n        agent_group = 
kwargs.pop(\"agent_group\")\n        is_embodied = self.config.actor_rollout_ref.model.model_type == \"embodied\"\n        \n        if is_embodied:\n            # Use dedicated embodied generation path (mirrors validation logic)\n            return self.generate_embodied_mode(agent_group, batch, **kwargs)\n        if self._multi_agent is False:\n            if self.rollout_mode == 'sync':\n                return self.generate_sync_mode(agent_group, batch)\n            else:\n                return self.generate_async_mode(batch)\n        else:\n            return self.generate_multi_agent_mode(config, batch)\n\n    @DistProfiler.annotate(role=\"compute_reward\")\n    def compute_reward(self, config, batch: TensorDict, **kwargs) -> NodeOutput:\n        \"\"\"Calculates rewards for a batch of generated sequences.\"\"\"\n        from loguru import logger\n        \n        if not self.check_mode() and kwargs[\"cur_tp_rank\"] != 0:\n            return NodeOutput(batch=batch, metrics={})\n        \n        tp_size = kwargs.pop(\"tp_size\")\n        if \"token_level_rewards\" in batch and batch[\"token_level_rewards\"].numel() > 0:\n            return NodeOutput(batch=batch, metrics={})\n        batch[\"global_token_num\"] = NonTensorData((torch.sum(batch[\"attention_mask\"], dim=-1) // tp_size).tolist())\n\n        reward_tensor, extra_infos = compute_reward(batch, self.reward_fn)\n        batch[\"token_level_scores\"] = reward_tensor\n\n        if extra_infos:\n            batch.update({k: np.array(v) for k, v in extra_infos.items()}, inplace=True)\n\n        metrics = {}\n        if config.algorithm.use_kl_in_reward:\n            kl_ctrl_in_reward = core_algos.get_kl_controller(config.algorithm.kl_ctrl)\n            batch, kl_metrics = apply_kl_penalty(batch, kl_ctrl_in_reward, config.algorithm.kl_penalty)\n            metrics.update(kl_metrics)\n        else:\n            batch[\"token_level_rewards\"] = batch[\"token_level_scores\"]\n        return NodeOutput(batch=batch, metrics=metrics)\n\n    \n    @DistProfiler.annotate(role=\"compute_old_log_prob\")\n    def compute_old_log_prob(self, config, batch: TensorDict, **kwargs) -> NodeOutput:\n        \"\"\"Computes log probabilities from the actor model before the policy update.\"\"\"\n        process_group = kwargs.pop(\"process_group\")\n        agent_group = kwargs.pop(\"agent_group\")\n        if \"global_token_num\" not in batch:\n            # in multi-agent, agentA may don't have reward node\n            # insert some info needed\n            batch[\"global_token_num\"] = NonTensorData(torch.sum(batch[\"attention_mask\"], dim=-1).tolist())\n        processed_data = agent_group[NodeRole.ACTOR].compute_log_prob(batch)\n        local_metrics = processed_data[\"metrics\"]  if \"metrics\" in processed_data else {}\n        if \"entropys\" in processed_data:\n            entropy = agg_loss(processed_data[\"entropys\"], processed_data[\"response_mask\"].to(\"cpu\"), config.actor_rollout_ref.actor.loss_agg_mode)\n            local_metrics[\"actor/entropy_loss\"] = entropy.item()\n\n        processed_data.pop(\"metrics\", None)\n        processed_data.pop(\"entropys\", None)\n\n        return NodeOutput(batch=processed_data, metrics=local_metrics)\n\n    @DistProfiler.annotate(role=\"compute_ref_log_prob\")\n    def compute_ref_log_prob(self, config, batch: TensorDict, **kwargs) -> NodeOutput:\n        \"\"\"Computes log probabilities from the frozen reference model.\"\"\"\n        agent_group = kwargs.pop(\"agent_group\")\n        
processed_data = agent_group[NodeRole.REFERENCE].compute_ref_log_prob(batch)\n        metrics = processed_data[\"metrics\"]\n        return NodeOutput(batch=processed_data, metrics=metrics)\n\n    @DistProfiler.annotate(role=\"compute_value\")\n    def compute_value(self, config, batch: TensorDict, **kwargs) -> NodeOutput:\n        \"\"\"Computes value estimates from the critic model.\"\"\"\n        agent_group = kwargs.pop(\"agent_group\")\n        processed_data = agent_group[NodeRole.CRITIC].compute_values(batch)\n        return NodeOutput(batch=processed_data)\n\n    @DistProfiler.annotate(role=\"compute_advantage\")\n    def compute_multi_agent_advantage(self, config, batch: TensorDict, **kwargs) -> NodeOutput:\n        adv_config = config.algorithm\n        rollout_config = config.actor_rollout_ref.rollout\n        cur_node = kwargs[\"cur_node\"]\n        if \"token_level_rewards\" not in batch.batch:\n            # Make sure the dependent agent's rewards have already been computed.\n            # GAE_MARFT advantage estimation assumes only the last agent defines an advantage node.\n            if depend_nodes := self.taskgraph.get_dependencies(cur_node.node_id):\n                depend_node = depend_nodes[0]\n                if adv_config.share_reward_in_agent:\n                    batch.batch[\"token_level_rewards\"] = batch.batch[f\"agent_group_{depend_node.agent_group}_token_level_rewards\"].clone()\n                else:\n                    batch.batch[\"token_level_rewards\"] = torch.zeros_like(batch.batch[f\"agent_group_{depend_node.agent_group}_token_level_rewards\"])\n                batch.batch[\"token_level_scores\"] = batch.batch[f\"agent_group_{depend_node.agent_group}_token_level_scores\"].clone()\n            else:\n                raise RuntimeError(f\"cur_node {cur_node.node_id} has no rewards and none of its dependencies provide rewards\")\n        if adv_config.adv_estimator == AdvantageEstimator.GAE_MARFT:\n            # The advantage node must be defined on the last agent.\n            cur_agent_id = len(self.multi_agent_group) - 1\n            agent_groups_ids = list(range(cur_agent_id))\n            kwargs[\"agent_group_ids\"] = agent_groups_ids\n            # Preceding agents may have no reward tokens of their own.\n            for agent_id in reversed(agent_groups_ids):\n                key_prefix = f\"agent_group_{agent_id}_token_level_rewards\"\n                if key_prefix not in batch.batch:\n                    pre_key_prefix = f\"agent_group_{agent_id + 1}_token_level_rewards\" if agent_id != cur_agent_id - 1 else \"token_level_rewards\"\n                    if adv_config.share_reward_in_agent:\n                        batch.batch[key_prefix] = batch.batch[pre_key_prefix].clone()\n                    else:\n                        batch.batch[key_prefix] = torch.zeros_like(batch.batch[pre_key_prefix])\n                batch.batch[f\"agent_group_{agent_id}_token_level_scores\"] = batch.batch[key_prefix].clone()\n\n        return NodeOutput(\n            batch=compute_advantage(\n                batch,\n                adv_estimator=adv_config.adv_estimator,\n                gamma=adv_config.gamma,\n                lam=adv_config.lam,\n                num_repeat=rollout_config.n,\n                norm_adv_by_std_in_grpo=adv_config.norm_adv_by_std_in_grpo,\n                weight_factor_in_cpgd=adv_config.weight_factor_in_cpgd,\n                multi_turn=rollout_config.multi_turn.enable,\n                **kwargs\n            )\n        )\n\n    @DistProfiler.annotate(role=\"compute_advantage\")\n    def compute_advantage(self, config, 
batch: TensorDict, **kwargs) -> NodeOutput:\n        \"\"\"Computes advantages and returns for PPO using GAE.\"\"\"\n        \n        if not self.check_mode() and kwargs[\"cur_tp_rank\"] != 0:\n            return NodeOutput(batch=batch, metrics={})\n        \n        if self._multi_agent:\n            return self.compute_multi_agent_advantage(config, batch, **kwargs)\n        algo_config = config.algorithm\n        return NodeOutput(\n            batch=compute_advantage(\n                batch,\n                adv_estimator=algo_config.adv_estimator,\n                gamma=algo_config.gamma,\n                lam=algo_config.lam,\n                norm_adv_by_std_in_grpo=algo_config.norm_adv_by_std_in_grpo,\n                weight_factor_in_cpgd=algo_config.weight_factor_in_cpgd,\n                **kwargs\n            )\n        )\n\n    @DistProfiler.annotate(role=\"train_critic\")\n    def train_critic(self, config, batch: TensorDict, **kwargs) -> NodeOutput:\n        \"\"\"Performs a single training step on the critic model.\"\"\"\n        agent_group = kwargs.pop(\"agent_group\")\n        process_group = kwargs.pop(\"process_group\")\n        processed_data = agent_group[NodeRole.CRITIC].update_critic(batch)\n        return NodeOutput(batch=processed_data, metrics=processed_data[\"metrics\"])\n\n    @DistProfiler.annotate(role=\"train_actor\")\n    def train_actor(self, config, batch: TensorDict, **kwargs) -> NodeOutput:\n        \"\"\"Performs a single training step on the actor (policy) model.\"\"\"\n        process_group = kwargs.pop(\"process_group\")\n        agent_group = kwargs.pop(\"agent_group\")\n        global_steps = batch[\"global_steps\"] if \"global_steps\" in batch else 0\n        if config.trainer.critic_warmup > global_steps:\n            return NodeOutput(batch=batch)  # Skip actor update during critic warmup\n        batch[\"multi_turn\"] = NonTensorData(self.config.actor_rollout_ref.rollout.multi_turn.enable)\n        processed_data = agent_group[NodeRole.ACTOR].update_actor(batch)\n        return NodeOutput(batch=processed_data, metrics=processed_data[\"metrics\"])\n\n\n# ==========================================================================================\n# Module 3: Worker and Environment Initialization\n# ==========================================================================================\n\n    def _initialize_worker(self):\n        \"\"\"Orchestrates the ordered initialization of all worker components.\"\"\"\n        self._rank = get_and_validate_rank()\n        self.taskgraph = get_taskgraph_for_rank(self._rank, self.taskgraph_mapping)\n\n        self._setup_distributed_environment()\n        self._setup_tokenizers()\n        self._setup_dataloader()\n        self._setup_reward_managers()\n        self._setup_role_worker_mapping()\n        self._initialize_node_workers()\n        self._profiler = DistProfiler(rank=self._rank, config=self.config.profiler)\n\n        # Initialize CheckpointManager - Note: will be fully initialized after workers are created\n        self.checkpoint_manager = None\n\n        # Initialize Validator - Note: will be initialized in init_graph() after all workers are ready\n        self.validator = None\n\n        # Initialize MetricsCollector - Note: will be initialized in init_graph() after all dependencies are ready\n        self.metrics_collector = None\n\n        if self._rank == 0:\n            logger.info(\"Rank 0: Initializing tracking logger...\")\n            from siirl.utils.logger.tracking import Tracking\n\n  
          self.logger = Tracking(\n                project_name=self.config.trainer.project_name,\n                experiment_name=self.config.trainer.experiment_name,\n                default_backend=self.config.trainer.logger,\n                config=self.config.to_dict(),\n            )\n            if self.enable_perf:\n                logger.warning(\"Performance tracking is enabled. This may impact training speed.\")\n\n    def _setup_distributed_environment(self):\n        \"\"\"Initializes the default process group and all required subgroups.\"\"\"\n\n        if not dist.is_initialized():\n            backend = (\n                f\"{get_nccl_backend()}\"\n                if self.world_size >= self.config.dag.backend_threshold\n                else f\"cpu:gloo,{get_device_name()}:{get_nccl_backend()}\"\n            )\n            logger.info(\n                f\"Rank {self._rank}: Initializing world size {self.world_size} default process group with '{backend}' \"\n                f\"backend.\"\n            )\n            dist.init_process_group(backend=backend)\n\n        if device_name == \"npu\":\n            # For NPU, metrics aggregation requires the hccl backend for device-to-device communication.\n            # This group is created regardless of world size for NPU environments.\n            gather_backend = get_nccl_backend()\n            self._gather_group = dist.new_group(backend=gather_backend)\n        else:\n            # For GPU, the original logic is preserved for backward compatibility.\n            # The gather group is only created if world_size < backend_threshold.\n            self._gather_group = dist.new_group(\n                backend=\"gloo\") if self.world_size < self.config.dag.backend_threshold else None\n\n        group_specs = self.process_group_manager.get_all_specs()\n        if not group_specs:\n            logger.warning(\"No process group specifications found in ProcessGroupManager.\")\n            return\n\n        #Builds all process groups defined in the ProcessGroupManager.\n        for name, spec in group_specs.items():\n            if not isinstance(spec, dict) or not (ranks := spec.get(\"ranks\")):\n                logger.warning(f\"Skipping group '{name}' due to invalid spec or missing 'ranks'.\")\n                continue\n            self.process_groups[name] = dist.new_group(ranks=ranks)\n        logger.debug(f\"Rank {self._rank}: Created {len(self.process_groups)} custom process groups.\")\n\n        self.inference_group_name_set = self.process_group_manager.get_process_group_for_node_type_in_subgraph(\n            self.taskgraph.graph_id, NodeType.MODEL_INFERENCE.value\n        )\n        self.train_group_name_set = self.process_group_manager.get_process_group_for_node_type_in_subgraph(\n            self.taskgraph.graph_id, NodeType.MODEL_TRAIN.value\n        )\n\n        # Ensure all ranks have finished group creation before proceeding.\n        dist.barrier(self._gather_group)\n        logger.info(f\"Rank {self._rank}: Distributed environment setup complete.\")\n\n    def _setup_tokenizers(self):\n        \"\"\"Initializes and caches tokenizers for all models in the task graph.\"\"\"\n        model_nodes = [\n            node\n            for node in self.taskgraph.nodes.values()\n            if node.node_type in [NodeType.MODEL_TRAIN, NodeType.MODEL_INFERENCE]\n        ]\n        if not model_nodes:\n            logger.warning(\"No model nodes found in the task graph. 
Tokenizer setup will be skipped.\")\n            return\n\n        for node in model_nodes:\n            agent_key = f\"group_key_{node.agent_group}\"\n            if agent_key not in self.tokenizer_mapping:\n                # Add robust check for missing configuration.\n                intern_config = node.config.get(DAGConstants.INTERN_CONFIG)\n                if not intern_config or not (model_dict := getattr(intern_config, \"model\", None)):\n                    logger.warning(f\"Node {node.node_id} is missing model config. Skipping tokenizer setup for it.\")\n                    continue\n\n                tokenizer_module = load_tokenizer(model_args=model_dict)\n                if tokenizer := tokenizer_module.get(\"tokenizer\"):\n                    tokenizer.padding_side = \"left\"  # Required for most causal LM generation\n                self.tokenizer_mapping[agent_key] = tokenizer_module\n        logger.info(f\"Rank {self._rank}: Initialized {len(self.tokenizer_mapping)} tokenizer(s).\")\n\n    def _setup_dataloader(self):\n        \"\"\"Initializes the data loader for training and validation.\"\"\"\n        rollout_nodes = [n for n in self.taskgraph.nodes.values() if n.node_type == NodeType.MODEL_INFERENCE]\n        if not rollout_nodes:\n            raise ValueError(\"At least one MODEL_INFERENCE node is required for dataloader setup.\")\n        self.first_rollout_node = rollout_nodes[0]\n\n        pg_assignment = self.process_group_manager.get_node_assignment(self.first_rollout_node.node_id)\n        if not (process_group_name := pg_assignment.get(\"process_group_name\")):\n            raise ValueError(\n                f\"Process group name not found for the first rollout node {self.first_rollout_node.node_id}.\"\n            )\n\n        self.dataloader_process_group = self.process_groups.get(process_group_name)\n        if self.dataloader_process_group is None:\n            raise ValueError(f\"Could not find process group '{process_group_name}' in the created groups.\")\n\n        self.dataloader_tensor_model_parallel_size = self.first_rollout_node.config[\n            DAGConstants.INTERN_CONFIG\n        ].rollout.tensor_model_parallel_size\n\n        self.dataloader = DataLoaderNode(\n            node_id=\"dataloader\",\n            global_config=self.config,\n            config={\n                \"group_world_size\": dist.get_world_size(self.dataloader_process_group),\n                \"group_rank\": dist.get_rank(self.dataloader_process_group),\n                \"group_parallel_size\": self.dataloader_tensor_model_parallel_size,\n                \"num_loader_workers\": self.config.data.num_loader_workers,\n                \"auto_repeat\": self.config.data.auto_repeat,\n            },\n        )\n        logger.info(f\"Rank {self._rank}: DataLoader initialized with {self.dataloader.total_training_steps} total training steps.\")\n\n    def _setup_reward_managers(self):\n        \"\"\"Initializes reward managers for training and validation.\"\"\"\n        self.validate_tokenizer = next(iter(self.tokenizer_mapping.values()), {}).get(\"tokenizer\")\n        if not self.validate_tokenizer:\n            logger.warning(\"No tokenizer loaded; reward functions might fail or use a default one.\")\n\n        self.reward_fn = create_reward_manager(\n            self.config,\n            self.validate_tokenizer,\n            num_examine=0,\n            max_resp_len=self.config.data.max_response_length,\n            overlong_buffer_cfg=self.config.reward_model.overlong_buffer,\n    
        **self.config.reward_model.reward_kwargs,\n        )\n        logger.info(f\"Rank {self._rank}: Reward managers initialized.\")\n\n    def _setup_role_worker_mapping(self):\n        \"\"\"Creates a mapping from NodeRole to the corresponding Worker implementation class.\"\"\"\n        self.role_worker_mapping: Dict[NodeRole, Type[Worker]] = {}\n        # Actor/Ref/Rollout/Critic workers\n        actor_strategy = self.config.actor_rollout_ref.actor.strategy\n        self.role_worker_mapping.update(get_worker_classes(self.config, actor_strategy))\n\n        # Reward model worker (if enabled)\n        if self.config.reward_model.enable:\n            reward_strategy = self.config.reward_model.strategy\n            reward_workers = get_worker_classes(self.config, reward_strategy)\n            if NodeRole.REWARD in reward_workers:\n                self.role_worker_mapping[NodeRole.REWARD] = reward_workers[NodeRole.REWARD]\n            else:\n                logger.warning(\n                    f\"Reward model is enabled, but no worker found for role REWARD with strategy {reward_strategy}.\"\n                )\n\n        log_role_worker_mapping(self.role_worker_mapping)\n\n    def _initialize_node_workers(self):\n        \"\"\"Instantiates worker objects for all nodes in the task graph.\"\"\"\n        for node in self.taskgraph.nodes.values():\n            if not should_create_worker(self.role_worker_mapping, node):\n                continue\n\n            worker_cls = self.role_worker_mapping.get(node.node_role)\n            if not worker_cls:\n                logger.warning(f\"No worker class found for role {node.node_role.name}. Skipping node {node.node_id}.\")\n                continue\n\n            node_worker_key = generate_node_worker_key(node)\n            if node_worker_key in self.workers:\n                continue\n\n            try:\n                node_process_group = self._get_node_process_group(node)\n                config = node.config.get(DAGConstants.INTERN_CONFIG)\n                if hasattr(config, \"actor\") and hasattr(config.actor, \"optim\"):\n                    config.actor.optim.total_training_steps = self.dataloader.total_training_steps\n                elif hasattr(config, \"optim\"):\n                    config.optim.total_training_steps = self.dataloader.total_training_steps\n                worker_args = {\"config\": config, \"process_group\": node_process_group}\n\n                # For separated workers (Megatron backend), no role parameter is needed\n                # Only legacy ActorRolloutRefWorker needs the role parameter\n                if hasattr(worker_cls, '__name__') and 'ActorRolloutRefWorker' in worker_cls.__name__:\n                    if node.node_role in DAGConstants.WORKER_ROLE_MAPPING:\n                        worker_args[\"role\"] = DAGConstants.WORKER_ROLE_MAPPING[node.node_role]\n                if node.agent_options and node.agent_options.share_instance:\n                    # cur agent share same critic with target agent\n                    self.multi_agent_group[node.agent_group][node.node_role] = self.multi_agent_group[node.agent_options.share_instance][node.node_role]\n                else:\n                    worker_instance = worker_cls(**worker_args)\n                    self.workers[node_worker_key] = worker_instance\n                    self.multi_agent_group[node.agent_group][node.node_role] = worker_instance\n                    self.agent_group_process_group[node.agent_group][node.node_role] = node_process_group\n     
               logger.success(\n                        f\"Rank {self._rank}: Successfully created worker '{worker_cls.__name__}' for node: {node.node_id}\"\n                    )\n\n            except Exception as e:\n                #  Explicitly log the failing node and worker class, then re-raise\n                # the exception to prevent silent failures.\n                logger.error(\n                    f\"Failed to create worker for node {node.node_id} with class {worker_cls.__name__}.\", exc_info=True\n                )\n                raise RuntimeError(f\"Worker instantiation failed for node {node.node_id}\") from e\n\n        if len(self.multi_agent_group) > 1:\n            self._multi_agent = True\n\n    def init_graph(self):\n        \"\"\"\n        Initializes the computation graph by loading models and restoring checkpoint state.\n\n        Executed after _initialize_worker() across all workers via Ray remote call.\n        This method include:\n        (1) model weight loading,\n        (2) weight sharding_manager setup,\n        (3) async/multi-agent init,\n        (4) validator init,\n        (5) metrics collector init,\n        (6) checkpoint restoration\n        \"\"\"\n\n        self._load_model_weights()\n\n        self._setup_sharding_manager()\n\n        self._setup_async_rollout()\n\n        self._setup_multi_agent_loop()\n\n        self._init_validator()\n\n        self._init_metrics_collector()\n\n        self._init_checkpoint_manager()\n        self.global_steps = self.checkpoint_manager.load_checkpoint()\n\n        dist.barrier(self._gather_group)\n\n    def _load_model_weights(self):\n        \"\"\"Loads model weights to GPU for all node workers.\"\"\"\n        logger.info(\"Loading model weights for all worker nodes...\")\n        initialized_workers = set()\n\n        for node in self.taskgraph.nodes.values():\n            if not should_create_worker(self.role_worker_mapping, node):\n                continue\n\n            worker_key = generate_node_worker_key(node)\n            if worker_key in initialized_workers:\n                continue\n\n            node_worker = self.workers[worker_key]\n            if not isinstance(node_worker, Worker):\n                raise TypeError(f\"Invalid worker type for node {node.node_id}: {type(node_worker).__name__}\")\n\n            node_worker.init_model()\n            initialized_workers.add(worker_key)\n\n        logger.success(\"All model weights loaded successfully.\")\n\n    def _setup_sharding_manager(self):\n        \"\"\"Sets up sharding managers for actor-rollout weight synchronization.\"\"\"\n        logger.info(f\"Setting up weight sharing infrastructure ({self.config.actor_rollout_ref.rollout.name})...\")\n\n        for agent_group, worker_dict in self.multi_agent_group.items():\n            if NodeRole.ACTOR in worker_dict and NodeRole.ROLLOUT in worker_dict:\n                try:\n                    setup_sharding_manager(\n                        self.config,\n                        self.agent_group_process_group,\n                        agent_group,\n                        worker_dict\n                    )\n                except Exception as e:\n                    logger.error(f\"Failed to set up sharding manager for agent group {agent_group}: {e}\", exc_info=True)\n                    raise\n\n        logger.success(\"Weight sharing infrastructure initialized.\")\n\n    def _setup_async_rollout(self):\n        \"\"\"Initializes async rollout server if configured.\"\"\"\n        if 
self.config.actor_rollout_ref.rollout.mode != \"async\":\n            return\n\n        logger.info(\"Initializing async rollout server...\")\n        for node in self.taskgraph.nodes.values():\n            if node.node_role == NodeRole.ROLLOUT:\n                self.rollout_mode = \"async\"\n                node_worker = self.workers[generate_node_worker_key(node)]\n                self.zmq_address = node_worker.get_zeromq_address()\n                self.init_async_server(node=node, node_worker=node_worker)\n\n        logger.success(\"Async rollout server initialized.\")\n\n    def _setup_multi_agent_loop(self):\n        \"\"\"Initializes multi-agent loop if in multi-agent mode.\"\"\"\n        if not self._multi_agent:\n            return\n\n        logger.info(\"Initializing multi-agent loop...\")\n        from siirl.execution.rollout_flow.multi_agent.multiagent_generate import MultiAgentLoop\n\n        self.multi_agent_loop = MultiAgentLoop(\n            self,\n            config=self.config.actor_rollout_ref,\n            node_workers=self.workers,\n            local_dag=self.taskgraph,\n            databuffer=self.data_buffers,\n            placement_mode='colocate'\n        )\n\n        logger.success(\"Multi-agent loop initialized.\")\n\n    def _init_validator(self):\n        \"\"\"Initializes validator for validation workflow.\"\"\"\n        logger.info(\"Initializing validator...\")\n        from siirl.dag_worker.validator import Validator\n\n        self.validator = Validator(\n            config=self.config,\n            dataloader=self.dataloader,\n            validate_tokenizer=self.validate_tokenizer,\n            multi_agent_group=self.multi_agent_group,\n            rollout_mode=self.rollout_mode,\n            async_rollout_manager=self._async_rollout_manager,\n            multi_agent_loop=getattr(self, 'multi_agent_loop', None),\n            multi_agent=self._multi_agent,\n            rank=self._rank,\n            world_size=self.world_size,\n            gather_group=self._gather_group,\n            first_rollout_node=self.first_rollout_node,\n            get_node_dp_info_fn=self._get_node_dp_info,\n            enable_perf=self.enable_perf,\n            metric_worker=self.metric_worker\n        )\n        logger.success(\"Validator initialized.\")\n\n    def _init_metrics_collector(self):\n        \"\"\"Initializes metrics collector for training metrics aggregation.\"\"\"\n        logger.info(\"Initializing metrics collector...\")\n        # from siirl.dag_worker.metrics_collector import MetricsCollector\n        # self.metric_worker.init()\n        # self.metrics_collector = MetricsCollector(\n        #     rank=self._rank,\n        #     world_size=self.world_size,\n        #     gather_group=self._gather_group,\n        #     taskgraph=self.taskgraph,\n        #     first_rollout_node=self.first_rollout_node,\n        #     get_node_dp_info_fn=self._get_node_dp_info,\n        #     multi_agent=self._multi_agent,\n        #     enable_perf=self.enable_perf,\n        # )\n        # logger.success(\"Metrics collector initialized.\")\n\n    def _init_checkpoint_manager(self):\n        \"\"\"Initializes checkpoint manager for saving/loading training state.\"\"\"\n        logger.info(\"Initializing checkpoint manager...\")\n        self.checkpoint_manager = CheckpointManager(\n            config=self.config,\n            rank=self._rank,\n            gather_group=self._gather_group,\n            workers=self.workers,\n            taskgraph=self.taskgraph,\n            
dataloader=self.dataloader,\n            first_rollout_node=self.first_rollout_node,\n            get_node_dp_info_fn=self._get_node_dp_info\n        )\n\n    def init_async_server(self, node:Node, node_worker):\n        #gather zmq_address to rank_0\n        _, dp_rank, tp_rank, tp_size, *_ = self._get_node_dp_info(node)\n        addr_len = len(self.zmq_address)\n        encoded_addr = torch.tensor([ord(c) for c in self.zmq_address], dtype=torch.uint8,\n                                device=torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\"))\n        zmq_addresses = []\n        if tp_rank == 0:\n            group_addrs = torch.zeros((tp_size, addr_len), dtype=torch.uint8, device=encoded_addr.device)\n            group_addrs[0] = encoded_addr\n            for i in range(1, tp_size):\n                src_rank = dp_rank * tp_size + i\n                dist.recv(group_addrs[i], src=src_rank)\n            for i in range(tp_size):\n                addr_str = ''.join([chr(c.item()) for c in group_addrs[i]])\n                zmq_addresses.append(addr_str)\n        else:\n            dist.send(encoded_addr, dst=dp_rank * tp_size)\n        if tp_rank == 0:\n            self._async_rollout_manager = AgentLoopManager(node.config[\"intern_config\"], dp_rank, os.environ['WG_PREFIX'], node_worker.rollout, zmq_addresses)\n\n# ==========================================================================================\n# Module 4: Utilities\n# ==========================================================================================\n    def put_data_to_buffers(\n        self, key: str,\n        data: TensorDict,\n        source_dp_size:int,\n        dest_dp_size: int,\n        enforce_buffer: bool,\n        timing_raw: Dict[str, float]\n    ):\n        \"\"\"\n        Puts data into the DataCoordinator by converting it into individual Samples.\n        The data is tagged with a 'key' to be retrieved by the correct downstream node.\n        \"\"\"\n        try:\n            batch_size = len(data) if data is not None else 0\n\n            if source_dp_size == dest_dp_size and not enforce_buffer:\n                with timer(self.enable_perf, f\"put_intern_data_{key}\", timing_raw):\n                    self.internal_data_cache[key] = data\n            else:\n                samples = Dict2Samples(data)\n                if not samples:\n                    logger.warning(f\"Rank {self._rank}: TensorDict for key '{key}' converted to 0 samples. 
Nothing to put.\")\n                    return\n\n                with timer(self.enable_perf, f\"put_samples_to_coordinator_{key}\", timing_raw):\n                    sample_infos = []\n                    for sample in samples:\n                        # Convert uid to string (handle tensor uid from postprocess_sampling)\n                        uid_val = getattr(sample, 'uid', uuid.uuid4().int)\n                        if isinstance(uid_val, torch.Tensor):\n                            uid_str = str(uid_val.item())  # Works for both int and string tensors\n                        elif hasattr(uid_val, 'tolist'):\n                            uid_str = str(uid_val.tolist())  # Handle numpy types\n                        else:\n                            uid_str = str(uid_val)\n                        \n                        sample_infos.append(SampleInfo(\n                            sum_tokens=getattr(sample, 'sum_tokens', int(sample.attention_mask.sum())),\n                            prompt_length=getattr(sample, 'prompt_length', 0),\n                            response_length=getattr(sample, 'response_length', 0),\n                            uid=uid_str,\n                            dict_info={\n                                'key': key,\n                                'source_dp_size': source_dp_size  # Store source DP size\n                            }\n                        ))\n                    \n                    # Although ray.put is called multiple times, it is more efficient than remote actor calls.\n                    # This is the main source of the remaining overhead, but it is necessary\n                    # to maintain sample-level traceability in the DataCoordinator.\n                    with timer(self.enable_perf, f\"ray_put_samples_{key}\", timing_raw):\n                        sample_refs = [ray.put(sample) for sample in samples]\n                    self.sample_ref_cache.extend(sample_refs)\n                    try:\n                        loop = asyncio.get_event_loop()\n                    except RuntimeError:\n                        loop = asyncio.new_event_loop()\n                        asyncio.set_event_loop(loop)\n\n                    caller_node_id = ray.get_runtime_context().get_node_id()\n\n                    put_future = self.data_coordinator.put_batch.remote(sample_infos, sample_refs, caller_node_id)\n                    loop.run_until_complete(put_future)\n\n                    if self._rank == 0:\n                        logger.info(f\"Rank 0: PUT {len(samples)} samples to DataCoordinator for '{key}'\")\n\n        except Exception as e:\n            logger.error(f\"Rank {self._rank}: Unexpected error in put_data_to_buffers for key '{key}': {e}\", exc_info=True)\n            raise\n\n    def get_data_from_buffers(\n        self,\n        key: str,\n        cur_dp_size: int,\n        cur_dp_rank: int,\n        timing_raw: Dict[str, float]\n    ) -> Optional[TensorDict]:\n        \"\"\"\n        Gets data from the DataCoordinator by filtering for a specific key,\n        then collates the resulting Samples back into a single TensorDict.\n        \n        Args:\n            key: The key to filter samples\n            cur_dp_size: Current node's DP size\n            cur_dp_rank: Current worker's DP rank\n            timing_raw: Timing dict for performance tracking\n        \"\"\"\n        with timer(self.enable_perf, f\"get_intern_data_{key}\", timing_raw):\n            if key in self.internal_data_cache:\n                cached_data = 
self.internal_data_cache.pop(key)\n                return cached_data\n        def key_filter(sample_info: SampleInfo) -> bool:\n            return sample_info.dict_info.get('key') == key\n\n        try:\n            loop = asyncio.get_event_loop()\n        except RuntimeError:\n            loop = asyncio.new_event_loop()\n            asyncio.set_event_loop(loop)\n\n        with timer(self.enable_perf, f\"get_samples_from_coordinator_{key}\", timing_raw):\n            try:\n                rollout_n = self.config.actor_rollout_ref.rollout.n if hasattr(self.config, 'actor_rollout_ref') else 1\n            except (AttributeError, KeyError):\n                rollout_n = 1\n            \n            if rollout_n is None or rollout_n < 1:\n                rollout_n = 1\n            \n            adjusted_batch_size = int(self.config.data.train_batch_size * rollout_n / cur_dp_size)\n            \n            logger.debug(\n                f\"Rank {self._rank}: Requesting from DataCoordinator: \"\n                f\"key='{key}', cur_dp={cur_dp_size}, \"\n                f\"adjusted_batch_size={adjusted_batch_size} (train_bs={self.config.data.train_batch_size} * rollout_n={rollout_n} / cur_dp={cur_dp_size})\"\n            )\n            \n            # Use filter_plugin to get only samples with matching key\n            # Use balance_partitions to optimize sample distribution by length\n            # Use cache_key to enable multi-rank caching within the same node\n            sample_refs = loop.run_until_complete(\n                self.data_coordinator.get_batch.remote(\n                    adjusted_batch_size,\n                    cur_dp_rank,\n                    filter_plugin=key_filter,\n                    balance_partitions=cur_dp_size,\n                    cache_key=key\n                )\n            )\n\n        # Check if dynamic sampling is enabled (DAPO/embodied)\n        embodied_sampling = self.config.algorithm.embodied_sampling\n        is_dynamic_sampling = (\n            self.config.algorithm.filter_groups.enable\n            or embodied_sampling.filter_accuracy\n            or embodied_sampling.filter_truncated\n        )\n\n        if not sample_refs:\n            if is_dynamic_sampling:\n                logger.debug(f\"Rank {self._rank}: Waiting for data accumulation for key '{key}' (need {adjusted_batch_size} samples)\")\n            else:\n                logger.warning(f\"Rank {self._rank}: DataCoordinator returned empty list for key '{key}' (adjusted_batch_size={adjusted_batch_size})\")\n            return None\n\n        if self._rank == 0:\n            logger.info(f\"Rank 0: GET {len(sample_refs)} samples from DataCoordinator for '{key}'\")\n\n        with timer(self.enable_perf, f\"ray_get_samples_{key}\", timing_raw):\n            samples = ray.get(sample_refs)\n\n        with timer(self.enable_perf, f\"collate_samples_{key}\", timing_raw):\n            tensordict = Samples2Dict(samples)\n\n        return tensordict\n\n    def reset_data_buffer(self):\n        \"\"\"\n        DEPRECATED with DataCoordinator. 
The get calls are now consuming.\n        This can be a no-op, but for safety, we could implement a clear if needed.\n        For now, it does nothing as intended.\n        \"\"\"\n        logger.debug(\"`reset_data_buffer` is a no-op with the new DataCoordinator model as gets are consuming.\")\n        if self._rank == 0:\n            self.data_coordinator.reset_cache.remote()\n\n    def _get_node_process_group(self, node: Node) -> ProcessGroup:\n        \"\"\"Retrieves the PyTorch ProcessGroup assigned to a specific graph node.\"\"\"\n        assignment = self.process_group_manager.get_node_assignment(node.node_id)\n        if not (assignment and (name := assignment.get(\"process_group_name\"))):\n            raise ValueError(f\"Process group assignment or name not found for node {node.node_id}.\")\n\n        pg = self.process_groups.get(name)\n        if pg is None:\n            raise ValueError(f\"Process group '{name}' for node {node.node_id} was not created or found.\")\n        return pg\n\n    def _get_node_dp_info(self, node: Node) -> tuple[int, int, int, int, int, int]:\n        \"\"\"\n        Calculates Data Parallel (DP), Tensor Parallel (TP), and Pipeline Parallel (PP) info for a node.\n\n        Returns:\n            tuple: (dp_size, dp_rank, tp_rank, tp_size, pp_rank, pp_size)\n        \"\"\"\n        reference_node = node\n        if node.node_type == NodeType.COMPUTE:\n            # If the node is a COMPUTE type, find its true data source ancestor.\n            ancestor = find_first_non_compute_ancestor(self.taskgraph, node.node_id)\n            if ancestor:\n                reference_node = ancestor\n            else:\n                # If no non-COMPUTE ancestor is found, it's a critical error.\n                raise RuntimeError(f\"Could not find any non-COMPUTE ancestor for COMPUTE node '{node.node_id}'. Please check your DAG graph configuration.\")\n\n        if reference_node.node_type == NodeType.COMPUTE:\n            group_world_size = self.config.trainer.n_gpus_per_node * self.config.trainer.nnodes\n            group_rank = dist.get_rank()\n        else:\n            process_group = self._get_node_process_group(reference_node)\n            group_world_size = dist.get_world_size(process_group)\n            group_rank = dist.get_rank(process_group)\n\n        # Get parallelism configuration based on backend strategy\n        tp_size, pp_size = get_parallelism_config(reference_node)\n\n        # Calculate total parallel size (TP * PP)\n        total_parallel_size = tp_size * pp_size\n\n        if group_world_size % total_parallel_size != 0:\n            raise ValueError(f\"Configuration error for node {node.node_id}: Group world size ({group_world_size}) is not divisible by total parallel size (TP={tp_size} * PP={pp_size} = {total_parallel_size}). 
Check your parallel configuration.\")\n\n        dp_size = group_world_size // total_parallel_size\n\n        # Calculate ranks within the data parallel group\n        dp_rank = group_rank // total_parallel_size\n\n        # Calculate position within the TP-PP grid\n        local_rank_in_tp_pp_group = group_rank % total_parallel_size\n\n        # For 2D parallelism: ranks are arranged as [PP0_TP0, PP0_TP1, ..., PP0_TP(tp_size-1), PP1_TP0, ...]\n        pp_rank = local_rank_in_tp_pp_group // tp_size\n        tp_rank = local_rank_in_tp_pp_group % tp_size\n\n        return dp_size, dp_rank, tp_rank, tp_size, pp_rank, pp_size\n\n    def get_zeromq_address(self):\n        return self.zmq_address\n\n    def multi_agent_put_log(self, key: str, data: TensorDict, agent_group: int, next_dp_size: int, timing_raw):\n        # This logic needs to be adapted to the new model. For now, it's a warning.\n        logger.warning(\"`multi_agent_put_log` is not yet refactored for DataCoordinator and is a no-op.\")\n        pass\n\n    def check_mode(self):\n        return self.rollout_mode == 'sync' and self._multi_agent == False"
  },
  {
    "path": "siirl/dag_worker/data_structures.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom dataclasses import dataclass, field\nfrom typing import Any, Dict\n\nimport torch\n\nfrom tensordict import TensorDict\n\n\n@dataclass\nclass ValidationResult:\n    \"\"\"A structured container for a single validation sample's results.\"\"\"\n\n    input_text: str\n    output_text: str\n    score: float\n    data_source: str\n    reward_tensor: torch.Tensor\n    extra_rewards: Dict[str, Any] = field(default_factory=dict)\n\n\n@dataclass\nclass ValidationPayload:\n    \"\"\"A lightweight, serializable container for validation metrics for efficient gathering.\"\"\"\n\n    input_text: str\n    score: float\n    data_source: str\n    extra_rewards: Dict[str, Any] = field(default_factory=dict)\n\n\n@dataclass\nclass NodeOutput:\n    \"\"\"A standardized return object for all node execution functions.\"\"\"\n\n    batch: TensorDict\n    metrics: Dict[str, Any] = field(default_factory=dict)\n"
  },
  {
    "path": "siirl/dag_worker/metric_aggregator.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n# Copyright 2025, Infrawaves. All rights reserved.\n# \n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nfrom collections import defaultdict\nfrom enum import Enum\nfrom typing import Any, Dict, List, Optional, Union\nimport torch\nimport torch.distributed as dist\n\nfrom siirl.utils.extras.device import get_device_id, get_device_name\n\nclass _ReduceOp(Enum):\n    \"\"\"Enumeration for supported reduction operations.\"\"\"\n\n    SUM = dist.ReduceOp.SUM\n    MAX = dist.ReduceOp.MAX\n    MIN = dist.ReduceOp.MIN\n\n\n# Configuration for metrics that require mean, max, and min aggregation.\n# Format: { \"key_in_local_data\": \"final_metric_prefix\" }\nMETRIC_CONFIG_FULL = {\n    \"score\": \"critic/score\",\n    \"rewards\": \"critic/rewards\",\n    \"advantages\": \"critic/advantages\",\n    \"returns\": \"critic/returns\",\n    \"values\": \"critic/values\",\n    \"response_length\": \"response/length\",\n    \"prompt_length\": \"prompt/length\",\n    \"correct_response_length\": \"response/correct_length\",\n    \"wrong_response_length\": \"response/wrong_length\",\n}\n\n# Configuration for metrics that only require mean aggregation.\n# Format: { \"key_in_local_data\": \"final_metric_prefix\" }\nMETRIC_CONFIG_MEAN_ONLY = {\n    \"response_clip_ratio\": \"response/clip_ratio\",\n    \"prompt_clip_ratio\": \"prompt/clip_ratio\",\n}\n\nclass DistributedMetricAggregator:\n    \"\"\"\n    A helper class to encapsulate the logic for aggregating metrics\n    in a distributed environment.\n    \"\"\"\n\n    def __init__(\n        self, local_metrics: Dict[str, Union[float, List[float], torch.Tensor]], group: Optional[dist.ProcessGroup]\n    ):\n        \"\"\"\n        Initializes the aggregator and prepares metrics for reduction.\n\n        Args:\n            local_metrics: The dictionary of metrics on the local rank.\n            group: The process group for distributed communication.\n        \"\"\"\n        self.group = group\n        device_name = get_device_name()\n        if device_name in [\"cuda\", \"npu\"]:\n            self.device = f\"{device_name}:{get_device_id()}\"\n        else:\n            self.device = \"cpu\"\n        self.op_buckets = self._bucket_local_metrics(local_metrics)\n\n    def _bucket_local_metrics(self, metrics: Dict, expected_keys: set = None) -> defaultdict:\n        \"\"\"\n        Parses local metrics and groups them by the required reduction operation.\n        This step also performs local pre-aggregation on lists and tensors.\n        This version correctly handles multi-element tensors as input.\n        \n        For Pipeline Parallel (PP), different stages may have different metrics.\n        This method ensures all ranks have the same set of keys by adding missing\n        metrics with default values (0.0) to avoid tensor shape mismatch in all_reduce.\n\n        Args:\n            metrics: Local metrics dictionary\n            expected_keys: Optional set of all 
expected metric keys across all ranks\n            \n        Returns:\n            A defaultdict containing keys and pre-aggregated values,\n            grouped by reduction operation type (_ReduceOp).\n        \"\"\"\n        buckets = defaultdict(list)\n        \n        # If expected_keys is provided, ensure all ranks have the same metrics\n        if expected_keys:\n            # define metrics that should be excluded from non-computing ranks\n            # these are training-specific metrics that only the last PP stage should contribute\n            training_metrics = {\n                'actor/pg_loss', 'actor/kl_loss', 'actor/entropy_loss', 'actor/ppo_kl',\n                'actor/pg_clipfrac', 'actor/pg_clipfrac_lower', 'actor/kl_coef',\n                'critic/vf_loss', 'critic/clipfrac'\n            }\n\n            # Token counting metrics should only be contributed by PP rank 0 to avoid double counting\n            token_counting_metrics = {\n                'perf/total_num_tokens/mean'\n            }\n            \n            for key in expected_keys:\n                if key not in metrics:\n                    # for training metrics: use None to indicate this rank shouldn't contribute\n                    # for other metrics: use 0.0 as default\n                    if any(key.startswith(prefix) for prefix in training_metrics) or key in token_counting_metrics:\n                        # mark as None - will be handled specially in aggregation\n                        metrics[key] = None\n                    else:\n                        # performance metrics get default value 0.0\n                        metrics[key] = 0.0\n        \n        for key in sorted(metrics.keys()):\n            value = metrics[key]\n            \n            # Skip None values (training metrics from non-contributing ranks)\n            if value is None:\n                # for training metrics that this rank (those ranks that are not the last PP stage) shouldn't contribute to,\n                # add with count=0 so it doesn't affect the average\n                buckets[_ReduceOp.SUM].append((key, (0.0, 0)))\n                continue\n\n            # Determine if the value is a list or a tensor that needs aggregation\n            is_list = isinstance(value, list)\n            is_tensor = isinstance(value, torch.Tensor)\n\n            if \"_max\" in key:\n                op_type = _ReduceOp.MAX\n                if is_tensor:\n                    # Use torch.max for tensors, get the scalar value\n                    local_val = torch.max(value).item() if value.numel() > 0 else 0.0\n                elif is_list:\n                    local_val = max(value) if value else 0.0\n                else: # Is a scalar float\n                    local_val = value\n                buckets[op_type].append((key, local_val))\n\n            elif \"_min\" in key:\n                op_type = _ReduceOp.MIN\n                if is_tensor:\n                    local_val = torch.min(value).item() if value.numel() > 0 else 0.0\n                elif is_list:\n                    local_val = min(value) if value else 0.0\n                else:\n                    local_val = value\n                buckets[op_type].append((key, local_val))\n\n            else:  # Default to mean calculation (SUM operation).\n                op_type = _ReduceOp.SUM\n                if is_tensor:\n                    local_sum = torch.sum(value).item()\n                    local_count = value.numel()\n                elif is_list:\n                   
 local_sum = sum(value) if value else 0.0\n                    local_count = len(value)\n                else: # Is a scalar float\n                    local_sum = value\n                    local_count = 1\n                buckets[op_type].append((key, (local_sum, local_count)))\n        return buckets\n\n    def aggregate_and_get_results(self) -> Dict[str, float]:\n        \"\"\"\n        Performs the distributed all_reduce operations and composes the final\n        metrics dictionary.\n\n        Returns:\n            A dictionary with the globally aggregated metrics.\n        \"\"\"\n        final_metrics = {}\n        for op_type, data in self.op_buckets.items():\n            if not data:\n                continue\n\n            keys, values = zip(*data)\n\n            if op_type == _ReduceOp.SUM:\n                sums, counts = zip(*values)\n                sum_tensor = torch.tensor(sums, dtype=torch.float32, device=self.device)\n                count_tensor = torch.tensor(counts, dtype=torch.float32, device=self.device)\n\n                if self.group is not None:\n                    dist.all_reduce(sum_tensor, op=op_type.value, group=self.group)\n                    dist.all_reduce(count_tensor, op=op_type.value, group=self.group)\n\n                global_sums = sum_tensor.cpu().numpy()\n                global_counts = count_tensor.cpu().numpy()\n\n                for i, key in enumerate(keys):\n                    final_metrics[key] = global_sums[i] / global_counts[i] if global_counts[i] > 0 else 0.0\n            else:  # MAX or MIN operations\n                value_tensor = torch.tensor(values, dtype=torch.float32, device=self.device)\n                if self.group is not None:\n                    dist.all_reduce(value_tensor, op=op_type.value, group=self.group)\n\n                global_values = value_tensor.cpu().numpy()\n                for i, key in enumerate(keys):\n                    final_metrics[key] = global_values[i]\n\n        return final_metrics\n\n\n"
  },
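The two-phase flow above (bucket locally, then one all_reduce per op type) hinges on a naming convention: keys containing `_max`/`_min` are reduced with MAX/MIN, and everything else is averaged by shipping a `(sum, count)` pair, so a rank with nothing to contribute can send `count=0` without skewing the global mean. The snippet below is a minimal single-process sketch of that convention only; it does not import the real aggregator, and the metric names are illustrative rather than taken from the configs.

```python
import torch

# Editorial sketch: mirrors the key -> reduction-op convention used by the
# aggregator above; metric names are illustrative, not from METRIC_CONFIG_*.
def bucket_value(key: str, value: torch.Tensor):
    if "_max" in key:
        return "MAX", torch.max(value).item() if value.numel() > 0 else 0.0
    if "_min" in key:
        return "MIN", torch.min(value).item() if value.numel() > 0 else 0.0
    # Means travel as (local_sum, local_count), so empty ranks can send (0.0, 0).
    return "SUM", (torch.sum(value).item(), value.numel())

local = {
    "actor/entropy/mean": torch.tensor([0.2, 0.4]),  # averaged via (sum, count)
    "response_length_max": torch.tensor([312.0]),    # reduced with MAX
}
for key, tensor in sorted(local.items()):
    print(key, bucket_value(key, tensor))
```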
  {
    "path": "siirl/dag_worker/metrics_collector.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nMetricsCollector: Orchestrates metrics aggregation for DAG worker training steps.\n\nThis module provides the MetricsCollector class which handles the collection,\naggregation, and formatting of training metrics across distributed workers.\n\"\"\"\n\nimport os\nimport psutil\nimport torch\nimport torch.distributed as dist\nfrom typing import Dict, Callable, Any\nfrom torch.distributed import ProcessGroup\nfrom tensordict import TensorDict\n# Handle different tensordict versions - NonTensorData location varies\ntry:\n    from tensordict import NonTensorData\nexcept ImportError:\n    from tensordict.tensorclass import NonTensorData\n\nfrom siirl.execution.dag import TaskGraph\nfrom siirl.execution.dag.node import Node, NodeRole\nfrom siirl.dag_worker.metric_aggregator import METRIC_CONFIG_FULL, METRIC_CONFIG_MEAN_ONLY\nfrom siirl.dag_worker.dag_utils import (\n    timer,\n    prepare_local_batch_metrics,\n    reduce_and_broadcast_metrics,\n    remove_prefix_from_dataproto,\n    add_prefix_to_dataproto,\n    add_prefix_to_metrics,\n)\nfrom siirl.utils.extras.device import get_device_name, get_device_id\nfrom siirl.utils.metrics.metric_utils import compute_throughout_metrics, compute_timing_metrics\nfrom loguru import logger\n\nclass MetricsCollector:\n    \"\"\"\n    Collects and aggregates training metrics across distributed workers.\n\n    This class orchestrates the final metrics collection for each training step,\n    including batch metrics, performance metrics, and timing information. 
It uses\n    the existing four-layer metrics architecture:\n    - Layer 1: torch.distributed (communication)\n    - Layer 2: DistributedMetricAggregator (aggregation engine)\n    - Layer 3: dag_utils functions (utility layer)\n    - Layer 4: MetricsCollector (orchestration layer)\n    \"\"\"\n\n    def __init__(\n        self,\n        rank: int,\n        world_size: int,\n        gather_group: ProcessGroup,\n        taskgraph: TaskGraph,\n        first_rollout_node: Node,\n        get_node_dp_info_fn: Callable,\n        multi_agent: bool = False,\n        enable_perf: bool = False,\n    ):\n        \"\"\"\n        Initialize MetricsCollector with explicit parameters.\n\n        Args:\n            rank: Current process rank\n            world_size: Total number of processes\n            gather_group: Process group for metric aggregation\n            taskgraph: DAG task graph for node iteration\n            first_rollout_node: Reference rollout node for configuration\n            get_node_dp_info_fn: Function to get node DP/TP/PP info\n            multi_agent: Whether in multi-agent mode\n            enable_perf: Whether to enable performance timing\n        \"\"\"\n        self.rank = rank\n        self.world_size = world_size\n        self.gather_group = gather_group\n        self.taskgraph = taskgraph\n        self.first_rollout_node = first_rollout_node\n        self.get_node_dp_info_fn = get_node_dp_info_fn\n        self.multi_agent = multi_agent\n        self.enable_perf = enable_perf\n\n    def collect_final_metrics(self, batch: TensorDict, timing_raw: dict) -> Dict[str, float]:\n        \"\"\"\n        Orchestrates the collection and computation of all metrics for a training step\n        using a highly efficient, all_reduce-based aggregation strategy.\n\n        This function replaces the old `compute -> reduce -> finalize` pipeline.\n\n        Args:\n            batch: Final batch data (TensorDict) containing all computed values\n            timing_raw: Dictionary of raw timing measurements\n\n        Returns:\n            Dictionary of aggregated metrics\n        \"\"\"\n        device_name = get_device_name()\n        if device_name == \"cuda\":\n            torch.cuda.reset_peak_memory_stats()\n        elif device_name == \"npu\":\n            torch.npu.reset_peak_memory_stats()\n\n        final_metrics = {}\n\n        # --- 1. Prepare all local metric data ---\n        use_critic = any(node.node_role == NodeRole.CRITIC for node in self.taskgraph.nodes.values())\n        local_data = prepare_local_batch_metrics(batch, use_critic=use_critic)\n\n        # --- 2. 
Build the dictionary for our generic, high-performance aggregator ---\n        # We want mean, max, and min for most standard metrics.\n        metrics_to_aggregate = {}\n\n        # Process metrics requiring mean, max, and min\n        for key, prefix in METRIC_CONFIG_FULL.items():\n            if key in local_data:\n                # The aggregator determines the operation from the key.\n                # We provide the same raw tensor for mean, max, and min calculations.\n                metrics_to_aggregate[f\"{prefix}/mean\"] = local_data[key]\n                metrics_to_aggregate[f\"{prefix}_max\"] = local_data[key]\n                metrics_to_aggregate[f\"{prefix}_min\"] = local_data[key]\n\n        # Process metrics requiring only mean\n        for key, prefix in METRIC_CONFIG_MEAN_ONLY.items():\n            if key in local_data:\n                metrics_to_aggregate[f\"{prefix}/mean\"] = local_data[key]\n\n        representative_actor_node = next(\n            (n for n in self.taskgraph.nodes.values() if n.node_role == NodeRole.ACTOR), self.first_rollout_node\n        )\n        _, _, _, _, pp_rank_in_group, _ = self.get_node_dp_info_fn(representative_actor_node)\n        # (1) For TP: we have already taken TP into account when we set global_token_num in compute_reward.\n        # see: siirl/workers/dag_worker/mixins/node_executors_mixin.py:compute_reward\n        # (2) For PP: only PP rank 0 contributes to avoid double counting within PP groups\n        # The aggregation will average across DP groups and multiply by world size to get global estimate\n        if pp_rank_in_group == 0:\n            local_token_sum = sum(batch[\"global_token_num\"])\n            metrics_to_aggregate[\"perf/total_num_tokens/mean\"] = float(local_token_sum)\n\n        # --- 3. Perform the aggregated, distributed reduction ---\n        with timer(self.enable_perf, \"metrics_aggregation\", timing_raw):\n            aggregated_metrics = reduce_and_broadcast_metrics(metrics_to_aggregate, self.gather_group)\n\n        # Post-process keys and values for the final output\n        for key, value in aggregated_metrics.items():\n            if \"_max\" in key and \"mem\" not in key:\n                final_metrics[key.replace(\"_max\", \"/max\")] = value\n            elif \"_min\" in key:\n                final_metrics[key.replace(\"_min\", \"/min\")] = value\n            else:\n                final_metrics[key] = value\n\n        # Special handling for total_num_tokens to convert mean back to sum\n        if \"perf/total_num_tokens/mean\" in final_metrics:\n            final_metrics[\"perf/total_num_tokens\"] = final_metrics.pop(\n                \"perf/total_num_tokens/mean\"\n            ) * dist.get_world_size(self.gather_group)\n\n        # --- 4. Handle special cases like Explained Variance ---\n        if use_critic:\n            # Determine the correct device for distributed operations\n            device_name = get_device_name()\n            if device_name in [\"cuda\", \"npu\"]:\n                device = f\"{device_name}:{get_device_id()}\"\n            else:\n                # Fallback to the device of an existing tensor. If it's CPU, all_reduce will fail,\n                # which is the original problem, indicating a deeper issue.\n                device = local_data[\"returns\"].device\n            # These components only need to be summed. 
We can do a direct all_reduce.\n            components_to_sum = {k: v for k, v in local_data.items() if k.endswith(\"_comp\")}\n            for tensor in components_to_sum.values():\n                if self.gather_group is not None:\n                    dist.all_reduce(tensor.to(device), op=dist.ReduceOp.SUM, group=self.gather_group)\n\n            # Now all ranks have the global sums and can compute the final value.\n            N = local_data[\"returns\"].numel()\n            total_N_tensor = torch.tensor([N], dtype=torch.int64, device=local_data[\"returns\"].device)\n            if self.gather_group is not None:\n                dist.all_reduce(total_N_tensor.to(device), op=dist.ReduceOp.SUM, group=self.gather_group)\n            global_N = total_N_tensor.item()\n\n            if global_N > 0:\n                global_returns_sum = final_metrics[\"critic/returns/mean\"] * global_N\n                global_returns_sq_sum = components_to_sum[\"returns_sq_sum_comp\"].item()\n                global_error_sum = components_to_sum[\"error_sum_comp\"].item()\n                global_error_sq_sum = components_to_sum[\"error_sq_sum_comp\"].item()\n\n                mean_returns = global_returns_sum / global_N\n                var_returns = (global_returns_sq_sum / global_N) - (mean_returns**2)\n\n                mean_error = global_error_sum / global_N\n                var_error = (global_error_sq_sum / global_N) - (mean_error**2)\n\n                final_metrics[\"critic/vf_explained_var\"] = 1.0 - var_error / (var_returns + 1e-8)\n            else:\n                final_metrics[\"critic/vf_explained_var\"] = 0.0\n\n        # --- 5. Add timing and other rank-0-only metrics ---\n        # Only rank 0 needs to compute these for logging.\n        if self.rank == 0:\n            batch[\"global_token_num\"] = NonTensorData([final_metrics.get(\"perf/total_num_tokens\", 0)])\n            final_metrics.update(compute_throughout_metrics(batch, timing_raw, dist.get_world_size()))\n            final_metrics[\"perf/process_cpu_mem_used_gb\"] = psutil.Process(os.getpid()).memory_info().rss / (1024**3)\n            timing_metrics = compute_timing_metrics(batch, timing_raw)\n            for key, value in timing_metrics.items():\n                if key.startswith(\"timing_s/\"):\n                    final_metrics[key.replace(\"timing_s/\", \"perf/delta_time/\")] = value\n\n            # Calculate rollout and actor log probs difference statistics\n            if \"rollout_log_probs\" in batch and \"old_log_probs\" in batch:\n                rollout_probs = torch.exp(batch[\"rollout_log_probs\"])\n                actor_probs = torch.exp(batch[\"old_log_probs\"])\n                rollout_probs_diff = torch.masked_select(\n                    torch.abs(rollout_probs.cpu() - actor_probs),\n                    batch[\"response_mask\"].bool().cpu()\n                )\n                if rollout_probs_diff.numel() > 0:\n                    final_metrics.update({\n                        \"training/rollout_probs_diff_max\": torch.max(rollout_probs_diff).item(),\n                        \"training/rollout_probs_diff_mean\": torch.mean(rollout_probs_diff).item(),\n                        \"training/rollout_probs_diff_std\": torch.std(rollout_probs_diff).item()\n                    })\n\n        # All ranks return the final metrics. Ranks other than 0 can use them if needed,\n        # or just ignore them. 
This is cleaner than returning an empty dict.\n        return final_metrics\n\n    def collect_multi_agent_final_metrics(self, batch: TensorDict, ordered_metrics: list, timing_raw: dict) -> list:\n        \"\"\"\n        Collects final metrics for multi-agent mode by iterating through agent rollout nodes.\n\n        Args:\n            batch: Final batch data (TensorDict) with multi-agent prefixes\n            ordered_metrics: List of (key, value) tuples to extend\n            timing_raw: Dictionary of raw timing measurements\n\n        Returns:\n            Extended list of ordered metrics\n        \"\"\"\n        node_queue = self.taskgraph.get_entry_nodes()\n        visited_nodes = set()\n        while node_queue:\n            cur_node = node_queue.pop(0)\n            if cur_node.node_id in visited_nodes:\n                continue\n            if cur_node.node_role != NodeRole.ROLLOUT:\n                break\n            batch = remove_prefix_from_dataproto(batch, cur_node)\n            final_metrics = self.collect_final_metrics(batch, timing_raw)\n            final_metrics = add_prefix_to_metrics(final_metrics, cur_node)\n            if final_metrics:\n                ordered_metrics.extend(sorted(final_metrics.items()))\n            if next_nodes := self.taskgraph.get_downstream_nodes(cur_node.node_id):\n                for n in next_nodes:\n                    if n.node_id not in visited_nodes:\n                        node_queue.append(n)\n            batch = add_prefix_to_dataproto(batch, cur_node)\n        return ordered_metrics\n"
  },
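After the distributed reduction, `collect_final_metrics` rewrites the aggregator-facing key suffixes (`_max`, `_min`) into the logging-facing `/max` and `/min` forms, deliberately leaving memory-related keys untouched. The stand-alone sketch below reproduces just that renaming step; the key names and values are hypothetical.

```python
# Hypothetical aggregated keys/values, shaped like the aggregator's output.
aggregated = {
    "critic/score/mean": 0.71,
    "critic/score_max": 0.98,
    "critic/score_min": 0.05,
    "perf/max_memory_allocated_gb_max": 41.2,  # "mem" keys keep their suffix
}
final_metrics = {}
for key, value in aggregated.items():
    if "_max" in key and "mem" not in key:
        final_metrics[key.replace("_max", "/max")] = value
    elif "_min" in key:
        final_metrics[key.replace("_min", "/min")] = value
    else:
        final_metrics[key] = value
print(final_metrics)
# {'critic/score/mean': 0.71, 'critic/score/max': 0.98,
#  'critic/score/min': 0.05, 'perf/max_memory_allocated_gb_max': 41.2}
```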
  {
    "path": "siirl/dag_worker/validator.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport time\nimport asyncio\nimport torch\nimport numpy as np\nimport torch.distributed as dist\nfrom ray.actor import ActorHandle\nfrom collections import defaultdict\nfrom typing import Dict, List, Callable, Any, Optional, Tuple\nfrom loguru import logger\nfrom tensordict import TensorDict\n# Handle different tensordict versions - NonTensorData location varies\ntry:\n    from tensordict import NonTensorData\nexcept ImportError:\n    from tensordict.tensorclass import NonTensorData\nfrom torch.distributed import ProcessGroup\n\nfrom siirl.data_coordinator import preprocess_dataloader\nfrom siirl.data_coordinator.dataloader import DataLoaderNode\nfrom siirl.execution.dag.node import NodeRole, Node\nfrom siirl.execution.scheduler.reward import create_reward_manager\nfrom siirl.dag_worker.data_structures import ValidationPayload, ValidationResult\nfrom siirl.dag_worker.dag_utils import dump_validation_generations, timer\nfrom siirl.utils.metrics.metric_utils import aggregate_validation_metrics\nfrom siirl.params import SiiRLArguments\n\n\n\nclass Validator:\n    \"\"\"\n    Handles the complete validation workflow for distributed RL training.\n\n    This class orchestrates the validation process including:\n    - Batch preparation and generation\n    - Reward scoring and result packaging\n    - Metrics aggregation across all ranks\n    - Performance logging\n\n    The validator operates in a distributed manner, coordinating across multiple\n    ranks and aggregating results on rank 0.\n    \"\"\"\n\n    def __init__(\n        self,\n        config: SiiRLArguments,\n        dataloader: DataLoaderNode,\n        validate_tokenizer: Any,\n        multi_agent_group: Dict[int, Dict[NodeRole, Any]],\n        rollout_mode: str,\n        async_rollout_manager: Optional[Any],\n        multi_agent_loop: Optional[Any],\n        multi_agent: bool,\n        rank: int,\n        world_size: int,\n        gather_group: ProcessGroup,\n        first_rollout_node: Node,\n        get_node_dp_info_fn: Callable,\n        enable_perf: bool = False,\n        metric_worker: ActorHandle = None\n    ):\n        \"\"\"\n        Initialize the Validator with explicit dependencies.\n\n        Args:\n            config: Training configuration\n            dataloader: Data loading utilities\n            val_reward_fn: Validation reward function\n            validate_tokenizer: Tokenizer for decoding sequences\n            multi_agent_group: Worker groups for generation (indexed by agent_group -> role)\n            rollout_mode: Generation mode ('sync' or 'async')\n            async_rollout_manager: Manager for async rollout (None if not using async)\n            multi_agent_loop: Manager for multi-agent generation (None if not multi-agent)\n            multi_agent: Whether in multi-agent mode\n            rank: Current process rank\n            world_size: Total number of processes\n   
         gather_group: Process group for distributed gathering\n            first_rollout_node: First rollout node for getting DP/TP/PP info\n            get_node_dp_info_fn: Function to get node parallelism info\n            enable_perf: Whether to enable performance profiling\n        \"\"\"\n        self.config = config\n        self.dataloader = dataloader\n        self.validate_tokenizer = validate_tokenizer\n        self.multi_agent_group = multi_agent_group\n        self.rollout_mode = rollout_mode\n        self.async_rollout_manager = async_rollout_manager\n        self.multi_agent_loop = multi_agent_loop\n        self.multi_agent = multi_agent\n        self.rank = rank\n        self.world_size = world_size\n        self.gather_group = gather_group\n        self.first_rollout_node = first_rollout_node\n        self.get_node_dp_info_fn = get_node_dp_info_fn\n        self.enable_perf = enable_perf\n        \n        self.val_reward_fn = create_reward_manager(\n            self.config,\n            self.validate_tokenizer,\n            num_examine=1,\n            max_resp_len=self.config.data.max_response_length,\n            overlong_buffer_cfg=self.config.reward_model.overlong_buffer,\n        )\n\n        # Validation timing tracking\n        self.val_timedict = defaultdict(float)\n        self.metric_worker = metric_worker\n\n    def validate(self, global_step: int) -> Dict[str, float]:\n        \"\"\"Performs validation based on dataset type.\"\"\"\n        # Correctly use the existing dataset_type parameter from the data config.\n        dataset_type = getattr(self.config.data, \"dataset_type\", \"llm\")\n        if dataset_type == \"embodied\":\n            return self._validate_embodied(global_step)\n        else:\n            return self._validate_text_generation(global_step)\n    \n    def _validate_embodied(self, global_step) -> Dict[str, float]:\n        \"\"\"\n        Performs embodied validation by running interactive episodes.\n        \n        This is the main entry point that orchestrates the entire validation flow:\n        1. Initialize timers and check prerequisites\n        2. Iterate through validation batches (each rank processes a shard)\n        3. Generate embodied episodes via rollout worker\n        4. Score results using val_reward_fn\n        5. Gather payloads from all ranks to rank 0\n        6. Aggregate and return final metrics\n        \n        Returns:\n            Dict[str, float]: Validation metrics (only on rank 0, empty dict on other ranks)\n        \"\"\"\n        # 1. Initialize timers\n        self.timers = defaultdict(float)\n        if self.rank == 0:\n            logger.info(\"=\" * 60)\n            logger.info(f\"Starting Embodied Validation @ Global Step {global_step}...\")\n            logger.info(\"=\" * 60)\n            self.timers[\"overall_start_time\"] = time.perf_counter()\n        \n        # 2. Check if num_val_batches > 0 to avoid unnecessary loops\n        if self.dataloader.num_val_batches <= 0:\n            if self.rank == 0:\n                logger.warning(\"num_val_batches is 0. Skipping embodied validation.\")\n            return {}\n        \n        # 3. 
Collect payloads from all batches\n        all_payloads = []\n        \n        for i in range(self.dataloader.num_val_batches):\n            if self.rank == 0:\n                logger.debug(f\"Processing embodied validation batch {i + 1}/{self.dataloader.num_val_batches}\")\n            \n            # 3.1 Prepare and generate\n            with timer(self.enable_perf, \"prep_and_generate\", self.val_timedict):\n                batch_proto = self._prepare_embodied_validation_batch()\n                generated_proto = self._generate_for_embodied_validation(batch_proto, global_step)\n                dist.barrier(self.gather_group)\n            \n            # 3.2 Score\n            with timer(self.enable_perf, \"score\", self.val_timedict):\n                batch_payloads = self._score_embodied_results(generated_proto)\n                all_payloads.extend(batch_payloads)\n        # 4. Gather payloads to rank 0 (only TP/PP master ranks prepare payload)\n        dp_size, _, tp_rank, _, pp_rank, _ = self.get_node_dp_info_fn(self.first_rollout_node)\n        with timer(self.enable_perf, \"gather_payloads\", self.val_timedict):\n            if tp_rank == 0 and pp_rank == 0:\n                self.metric_worker.submit_metric(self._aggregate_and_log_embodied_metrics(all_payloads, global_step), dp_size)\n        # with timer(self.enable_perf, \"gather_payloads\", self.val_timedict):\n        #     payloads_for_metrics = []\n        #     if tp_rank == 0 and pp_rank == 0:\n        #         payloads_for_metrics = all_payloads\n            \n        #     gathered_payloads_on_rank0 = [None] * self.world_size if self._rank == 0 else None\n        #     dist.gather_object(payloads_for_metrics, gathered_payloads_on_rank0, dst=0, group=self._gather_group)\n        \n        # # 5. Rank 0 aggregates and logs metrics\n        # if self.rank == 0:\n        #     flat_payload_list = [p for sublist in gathered_payloads_on_rank0 if sublist for p in sublist]\n        #     final_metrics = self._aggregate_and_log_embodied_metrics(flat_payload_list)\n        \n        dist.barrier(self.gather_group)\n        \n        return\n    \n    \n    def _validate_text_generation(self, global_step: int) -> Dict[str, float]:\n        \"\"\"\n        Executes the complete validation workflow.\n\n        This is the main entry point for validation. It:\n        1. Prepares validation batches from the dataloader\n        2. Generates sequences using the rollout model\n        3. Scores the generated sequences\n        4. Aggregates metrics across all ranks\n        5. Logs performance breakdown (on rank 0)\n\n        Args:\n            global_step: Current training step (for logging and checkpointing)\n\n        Returns:\n            Dict[str, float]: Validation metrics (only on rank 0, empty dict on other ranks)\n        \"\"\"\n        self.val_timedict = defaultdict(float)\n        if self.rank == 0:\n            logger.info(\"=\" * 60)\n            logger.info(f\"Starting Validation @ Global Step {global_step}...\")\n            logger.info(\"=\" * 60)\n            self.val_timedict[\"overall_start_time\"] = time.perf_counter()\n\n        all_scored_results: List[ValidationResult] = []\n\n        # Check if num_val_batches > 0 to avoid unnecessary loops.\n        if self.dataloader.num_val_batches <= 0:\n            if self.rank == 0:\n                logger.warning(\"num_val_batches is 0. 
Skipping validation.\")\n            return {}\n        sample_turns = []\n        for i in range(self.dataloader.num_val_batches):\n            if self.rank == 0:\n                logger.debug(f\"Processing validation batch {i + 1}/{self.dataloader.num_val_batches}\")\n\n            with timer(self.enable_perf, \"prep_and_generate\", self.val_timedict):\n                test_batch = self.dataloader.run(is_validation_step=True)\n                val_batch = preprocess_dataloader(test_batch, self.config.actor_rollout_ref.rollout.val_kwargs.n)\n                generated_proto = self._generate_for_validation(val_batch)\n                dist.barrier(self.gather_group)\n\n            with timer(self.enable_perf, \"score_and_package\", self.val_timedict):\n                scored_results = self._score_and_package_results(generated_proto)\n                all_scored_results.extend(scored_results)\n                    \n            \n        dump_validation_generations(self.config, global_step, self.rank, all_scored_results)\n        dist.barrier(self.gather_group)\n\n        dp_size, _, tp_rank, _, pp_rank, _ = self.get_node_dp_info_fn(self.first_rollout_node)\n        \n        # Gather all payloads to rank 0\n        with timer(self.enable_perf, \"gather_payloads\", self.val_timedict):\n            if tp_rank == 0 and pp_rank == 0:\n                # Only the master rank of the TP group (tp_rank=0) and first PP stage (pp_rank=0) prepares the payload.\n                payloads_for_metrics = [\n                    ValidationPayload(r.input_text, r.score, r.data_source, r.extra_rewards) for r in all_scored_results\n                ]\n                self.metric_worker.submit_metric(self._aggregate_and_log_validation_metrics(payloads_for_metrics), dp_size)\n\n        dist.barrier(self.gather_group)\n        return \n\n    def _generate_for_validation(self, batch: TensorDict) -> TensorDict:\n        \"\"\"\n        Generates sequences using the rollout worker for a validation batch.\n\n        Supports three generation modes:\n        - Sync mode: Direct generation via rollout worker\n        - Async mode: Generation via async rollout manager\n        - Multi-agent mode: Generation via multi-agent loop\n\n        Args:\n            batch_proto: Input batch containing prompts\n\n        Returns:\n            TensorDict: Batch with generated sequences added\n        \"\"\"\n        rollout_worker = self.multi_agent_group[0][NodeRole.ROLLOUT]\n        val_kwargs = self.config.actor_rollout_ref.rollout.val_kwargs\n\n        prompt_texts = self.validate_tokenizer.batch_decode(\n            batch[\"input_ids\"], skip_special_tokens=True\n        )\n        batch[\"prompt_texts\"] = prompt_texts\n        batch[\"eos_token_id\"] = NonTensorData(self.validate_tokenizer.eos_token_id)\n        batch[\"pad_token_id\"] = NonTensorData(self.validate_tokenizer.eos_token_id)\n        batch[\"recompute_log_prob\"] = NonTensorData(self.validate_tokenizer.eos_token_id)\n        batch[\"validate\"] = NonTensorData(True)\n        batch[\"do_sample\"] = NonTensorData(val_kwargs.do_sample)\n\n        output = None\n        if self.multi_agent is False:\n            if self.rollout_mode == 'sync':\n                output = rollout_worker.generate_sequences(batch)\n            elif self.async_rollout_manager:\n                loop = asyncio.get_event_loop()\n                output = loop.run_until_complete(self.async_rollout_manager.generate_sequences(batch))\n        else:\n            output = 
self.multi_agent_loop.generate_sequence(batch)\n\n        if output is not None:\n            return output\n        return batch\n\n    def _score_and_package_results(self, generated_proto: TensorDict) -> List[ValidationResult]:\n        \"\"\"\n        Scores generated sequences and packages them into ValidationResult objects.\n\n        This method:\n        1. Computes rewards for generated sequences (or uses pre-computed rewards)\n        2. Decodes input prompts and output responses\n        3. Packages everything into ValidationResult objects\n        4. Filters out padded duplicates for trailing ranks\n\n        Args:\n            generated_proto: Batch containing generated sequences\n\n        Returns:\n            List[ValidationResult]: Scored and packaged validation results\n            Dict: extra_rewards \n        \"\"\"\n        if self.rollout_mode == 'async' and self.async_rollout_manager is None:\n            return []\n        if self.multi_agent and 'responses' not in generated_proto:\n            return []\n        if \"token_level_rewards\" in generated_proto:\n            reward_result = {\"reward_tensor\": generated_proto[\"token_level_rewards\"],\n                             \"reward_extra_info\": {}}\n        else:\n            reward_result = self.val_reward_fn(generated_proto, return_dict=True)\n        scores = reward_result[\"reward_tensor\"].sum(-1).cpu()\n\n        input_texts = generated_proto[\"prompt_texts\"] if \"prompt_texts\" in generated_proto else None\n        if input_texts is None:\n            logger.error(\n                \"FATAL: `prompt_texts` not found in `non_tensor_batch`. \"\n                \"The prompt data was lost during the process. Falling back to decoding the full sequence, \"\n                \"but please be aware the resulting `input_text` will be INCORRECT (it will contain prompt + response).\"\n            )\n            # Fallback to prevent a crash, but the output is known to be wrong.\n            input_texts = self.validate_tokenizer.batch_decode(\n                generated_proto[\"input_ids\"], skip_special_tokens=True\n            )\n\n        output_texts = self.validate_tokenizer.batch_decode(generated_proto[\"responses\"], skip_special_tokens=True)\n        data_sources = generated_proto[\"data_source\"] if \"data_source\" in generated_proto else [\"unknown\"] * len(scores)\n        extra_info = generated_proto[\"extra_info\"] if \"data_source\" in generated_proto else [None] * len(scores)\n\n        packaged_results = []\n        for i in range(len(scores)):\n            if self.dataloader.is_val_trailing_rank and isinstance(extra_info[i], dict) and extra_info[i].get(\"padded_duplicate\", None):\n                logger.debug(f\"Rank {self.rank} skip append padded duplicate item {i}: score={scores[i].item()}\")\n                continue\n            extra_rewards = {k: v[i] for k, v in reward_result.get(\"reward_extra_info\", {}).items()}\n            packaged_results.append(ValidationResult(input_texts[i], output_texts[i], scores[i].item(), data_sources[i], reward_result[\"reward_tensor\"][i], extra_rewards))\n        return packaged_results\n\n    def _aggregate_and_log_validation_metrics(self, all_payloads: List[ValidationPayload]) -> Dict[str, float]:\n        \"\"\"\n        Aggregates all validation results and logs performance (rank 0 only).\n\n        This method:\n        1. Calls _aggregate_validation_results to compute final metrics\n        2. 
Logs a detailed performance breakdown of the validation process\n        3. Reports total validation time\n\n        Args:\n            all_payloads: All validation payloads gathered from all ranks\n\n        Returns:\n            Dict[str, float]: Final aggregated validation metrics\n        \"\"\"\n        if not all_payloads:\n            logger.warning(\"Validation finished with no results gathered on Rank 0 to aggregate.\")\n            return {}\n\n        \n        with timer(self.enable_perf, \"final_aggregation\", self.val_timedict):\n            final_metrics = self._aggregate_validation_results(all_payloads)\n\n        # Log performance breakdown\n        total_time = time.perf_counter() - self.val_timedict.pop(\"overall_start_time\", time.perf_counter())\n        if self.rank == 0:\n            logger.info(f\"Rank 0: Aggregating {len(all_payloads)} validation results...\")\n            logger.info(\"--- Validation Performance Breakdown (Rank 0) ---\")\n            for name, duration in self.val_timedict.items():\n                logger.info(f\"  Total {name.replace('_', ' ').title():<25}: {duration:.4f}s\")\n            known_time = sum(self.val_timedict.values())\n            logger.info(f\"  {'Other/Overhead':<25}: {max(0, total_time - known_time):.4f}s\")\n            logger.info(f\"  {'TOTAL VALIDATION TIME':<25}: {total_time:.4f}s\")\n            logger.info(\"=\" * 51)\n        return final_metrics\n\n    def _aggregate_validation_results(self, all_payloads: List[ValidationPayload]) -> Dict[str, float]:\n        \"\"\"\n        Computes the final metric dictionary from all gathered validation payloads.\n\n        This method processes validation results to compute:\n        - Mean/majority/best metrics for different data sources\n        - Pass@N accuracy metrics\n        - Per-data-source test scores\n\n        Args:\n            all_payloads: All validation payloads from all ranks\n\n        Returns:\n            Dict[str, float]: Final validation metrics organized by data source and metric type\n        \"\"\"\n        data_sources = [p.data_source for p in all_payloads]\n        sample_inputs = [p.input_text for p in all_payloads]\n\n        infos_dict = defaultdict(list)\n        for p in all_payloads:\n            infos_dict[\"reward\"].append(p.score)\n            for key, value in p.extra_rewards.items():\n                infos_dict[key].append(value)\n\n        data_src2var2metric2val = aggregate_validation_metrics(data_sources=data_sources, sample_inputs=sample_inputs, infos_dict=infos_dict)\n\n        metric_dict = {}\n        for data_source, var2metric2val in data_src2var2metric2val.items():\n            core_var = \"acc\" if \"acc\" in var2metric2val else \"reward\"\n            for var_name, metric2val in var2metric2val.items():\n                if not metric2val:\n                    continue\n\n                # Robustly parse '@N' to prevent crashes from malformed metric names.\n                n_max_values = []\n                for name in metric2val.keys():\n                    if \"@\" in name and \"/mean\" in name:\n                        try:\n                            n_val = int(name.split(\"@\")[-1].split(\"/\")[0])\n                            n_max_values.append(n_val)\n                        except (ValueError, IndexError):\n                            continue  # Ignore malformed metric names\n\n                n_max = max(n_max_values) if n_max_values else 1\n\n                for metric_name, metric_val in metric2val.items():\n          
          is_core_metric = (var_name == core_var) and any(metric_name.startswith(pfx) for pfx in [\"mean\", \"maj\", \"best\"]) and (f\"@{n_max}\" in metric_name)\n\n                    metric_sec = \"val-core\" if is_core_metric else \"val-aux\"\n                    pfx = f\"{metric_sec}/{data_source}/{var_name}/{metric_name}\"\n                    metric_dict[pfx] = metric_val\n\n        # Re-calculate test_score per data source\n        data_source_rewards = defaultdict(list)\n        for p in all_payloads:\n            data_source_rewards[p.data_source].append(p.score)\n\n        for source, rewards in data_source_rewards.items():\n            if rewards:\n                metric_dict[f\"val/test_score/{source}\"] = np.mean(rewards)\n\n        return metric_dict\n\n    \n    def _prepare_embodied_validation_batch(self) -> TensorDict:\n        \"\"\"\n        Fetches and prepares a single embodied validation batch.\n        \n        Unlike text-generation validation, embodied validation does NOT repeat batches\n        because running embodied episodes is expensive.\n        \n        Returns:\n            TensorDict: The validation batch\n        \"\"\"\n        test_batch = self.dataloader.run(is_validation_step=True)\n        test_batch_proto = preprocess_dataloader(test_batch)\n\n        # No repeat for embodied validation (confirmed by user)\n        return test_batch_proto\n    \n    def _generate_for_embodied_validation(self, batch: TensorDict, global_step:int) -> TensorDict:\n        \"\"\"\n        Generates embodied episodes using the rollout worker.\n        \n        Sets up meta_info for validation mode (validate=True, do_sample=False) and\n        calls the appropriate generation method based on rollout configuration.\n        \n        Args:\n            batch: The input batch containing task information\n            \n        Returns:\n            TensorDict: The batch with generated episode data (actions, observations, rewards, etc.)\n        \"\"\"\n        rollout_worker = self.multi_agent_group[0][NodeRole.ROLLOUT]\n        \n        # Set meta_info for embodied validation\n    \n        batch[\"eos_token_id\"] = NonTensorData(self.validate_tokenizer.eos_token_id)\n        batch[\"pad_token_id\"] = NonTensorData(self.validate_tokenizer.pad_token_id)\n        batch[\"recompute_log_prob\"] = NonTensorData(False)\n        batch[\"validate\"] = NonTensorData(True)\n        batch[\"do_sample\"] = NonTensorData(False)\n        batch[\"global_steps\"] = NonTensorData(global_step)\n        \n        logger.info(\n            f\"[Embodied Validation] Batch variables: \"\n            f\"eos_token_id={batch['eos_token_id']}, \"\n            f\"pad_token_id={batch['pad_token_id']}, \"\n            f\"recompute_log_prob={batch['recompute_log_prob']}, \"\n            f\"validate={batch['validate']}, \"\n            f\"do_sample={batch['do_sample']}, \"\n            f\"global_steps={batch['global_steps']}\"\n        )\n        \n        # Generate episodes based on rollout mode\n        output = None\n        output = rollout_worker.generate_sequences(batch)\n        \n        # Union the output with the original batch\n        if output is not None:\n            return output\n        \n        return batch\n    \n    def _score_embodied_results(self, generated_proto: TensorDict) -> List[ValidationPayload]:\n        \"\"\"\n        Scores generated embodied episodes using val_reward_fn and packages lightweight payloads.\n        \n        Unlike text-generation, embodied 
validation:\n        - Uses val_reward_fn.verify() instead of val_reward_fn()\n        - Returns (verifier_score, reward_metrics, format_metrics, reward_format_metrics)\n        - Doesn't need to decode text prompts/responses\n        \n        Args:\n            generated_proto: The batch with generated episode data\n            \n        Returns:\n            List[ValidationPayload]: Lightweight payloads for gathering\n        \"\"\"\n        if self.val_reward_fn:\n            verifier_score, reward_metrics, format_metrics, reward_format_metrics = self.val_reward_fn.verify(generated_proto)\n            reward_tensor = torch.tensor(verifier_score, dtype=torch.float32).unsqueeze(-1)\n            \n            # Store batch-level metrics (without prefix, will be added during aggregation)\n            batch_metrics = {\n                'reward_metrics': reward_metrics,\n                'format_metrics': format_metrics,\n                'reward_format_metrics': reward_format_metrics,\n            }\n        \n        # 3. Get data sources (task suite name)\n        task_suite_name = getattr(self.config.actor_rollout_ref.embodied.env, 'env_name', 'unknown_task')\n        data_sources = generated_proto.get(\"data_source\", [task_suite_name] * reward_tensor.shape[0])\n        \n        # 4. Get input identifiers (for debugging/logging)\n        # For embodied tasks, we use task_file_name if available\n        if \"task_file_name\" in generated_proto:\n            task_file_names_bytes = generated_proto[\"task_file_name\"].cpu().numpy()\n            input_texts = []\n            for tfn_bytes in task_file_names_bytes:\n                # Decode bytes to string\n                tfn_str = bytes(tfn_bytes).decode('utf-8').rstrip('\\x00')\n                input_texts.append(f\"Task: {tfn_str}\")\n        else:\n            # Fallback: use generic episode identifiers\n            input_texts = [f\"Episode_{i}\" for i in range(reward_tensor.shape[0])]\n        \n        # 5. Compute scores\n        scores = reward_tensor.sum(-1).cpu()\n        \n        # 6. Package payloads (lightweight, no full reward tensor)\n        # Only the first sample in each batch carries the batch_metrics to avoid duplication\n        packaged_payloads = []\n        for i in range(len(scores)):\n            payload = ValidationPayload(\n                input_text=input_texts[i],\n                score=scores[i].item(),\n                data_source=data_sources[i],\n                extra_rewards=batch_metrics if i == 0 else {}  # Only first sample carries batch metrics\n            )\n            packaged_payloads.append(payload)\n        \n        return packaged_payloads\n    \n    \n    def _aggregate_and_log_embodied_metrics(self, all_payloads: List[ValidationPayload], global_step: int) -> Dict[str, float]:\n        \"\"\"\n        On Rank 0, aggregates all embodied validation results and logs performance.\n        \n        This function runs only on rank 0 after gathering payloads from all ranks.\n        \n        Args:\n            all_payloads: All ValidationPayload objects gathered from all ranks\n            \n        Returns:\n            Dict[str, float]: Final validation metrics\n        \"\"\"\n        if not all_payloads:\n            logger.warning(\"Embodied validation finished with no results gathered on Rank 0.\")\n            return {}\n        \n        logger.info(f\"Rank 0: Aggregating {len(all_payloads)} embodied validation results...\")\n        \n        # 1. 
Aggregate results\n        with timer(self.enable_perf, \"final_aggregation\", self.val_timedict):\n            final_metrics = self._aggregate_embodied_results(all_payloads)\n        \n        # 2. Log performance breakdown\n        if self.rank == 0:\n            total_time = time.perf_counter() - self.timers.pop(\"overall_start_time\", time.perf_counter())\n            logger.info(\"=\" * 60)\n            logger.info(\"--- Embodied Validation Performance Breakdown (Rank 0) ---\")\n            for name, duration in self.timers.items():\n                logger.info(f\"  Total {name.replace('_', ' ').title():<30}: {duration:.4f}s\")\n            known_time = sum(self.timers.values())\n            logger.info(f\"  {'Other/Overhead':<30}: {max(0, total_time - known_time):.4f}s\")\n            logger.info(f\"  {'TOTAL EMBODIED VALIDATION TIME':<30}: {total_time:.4f}s\")\n            logger.info(\"=\" * 60)\n            \n            # 3. Log final metrics\n            logger.info(f\"Embodied Validation Metrics (Global Step {global_step}):\")\n            for metric_name, metric_value in sorted(final_metrics.items()):\n                logger.info(f\"  {metric_name}: {metric_value:.4f}\")\n            logger.info(\"=\" * 60)\n        \n        return final_metrics\n    \n    def _aggregate_embodied_results(self, all_payloads: List[ValidationPayload]) -> Dict[str, float]:\n        \"\"\"\n        Computes the final metric dictionary from all gathered embodied validation payloads.\n        \n        Aggregation strategy:\n        1. Group rewards by data_source\n        2. Compute mean reward per data_source and overall\n        3. Collect and average batch-level metrics from all batches\n        \n        Args:\n            all_payloads: All ValidationPayload objects from all ranks\n            \n        Returns:\n            Dict[str, float]: Final metrics with structure:\n                - embodied/test_score/{data_source}: Mean reward per data source\n                - embodied/test_score/all: Overall mean reward\n                - embodied/reward/{metric_name}: Averaged reward metrics across all batches\n                - embodied/format/{metric_name}: Averaged format metrics across all batches\n                - embodied/reward_format/{metric_name}: Averaged combined metrics across all batches\n        \"\"\"\n        # 1. Group rewards by data_source\n        data_source_rewards = {}\n        for payload in all_payloads:\n            data_source = payload.data_source\n            if data_source not in data_source_rewards:\n                data_source_rewards[data_source] = []\n            data_source_rewards[data_source].append(payload.score)\n        \n        # 2. Compute per-data-source metrics\n        metric_dict = {}\n        for data_source, rewards in data_source_rewards.items():\n            metric_dict[f'val/test_score/{data_source}'] = np.mean(rewards)\n        \n        # 3. Compute overall mean reward\n        all_rewards = [p.score for p in all_payloads]\n        metric_dict['val/test_score/all'] = np.mean(all_rewards)\n        \n        # 4. 
Collect and average batch-level metrics from all batches\n        batch_metrics_list = []\n        for payload in all_payloads:\n            if payload.extra_rewards and 'reward_metrics' in payload.extra_rewards:\n                batch_metrics_list.append(payload.extra_rewards)\n        \n        if batch_metrics_list:\n            num_batches = len(batch_metrics_list)\n            logger.info(f\"[_aggregate_embodied_results] Collected metrics from {num_batches} batches\")\n            \n            # 4.1 Collect all reward_metrics (including 'all' and per-task metrics)\n            aggregated_reward_metrics = {}\n            for batch_metrics in batch_metrics_list:\n                for key, value in batch_metrics['reward_metrics'].items():\n                    if key not in aggregated_reward_metrics:\n                        aggregated_reward_metrics[key] = []\n                    aggregated_reward_metrics[key].append(value)\n            \n            # 4.2 Collect all format_metrics\n            aggregated_format_metrics = {}\n            for batch_metrics in batch_metrics_list:\n                for key, value in batch_metrics['format_metrics'].items():\n                    if key not in aggregated_format_metrics:\n                        aggregated_format_metrics[key] = []\n                    aggregated_format_metrics[key].append(value)\n            \n            # 4.3 Collect all reward_format_metrics\n            aggregated_reward_format_metrics = {}\n            for batch_metrics in batch_metrics_list:\n                for key, value in batch_metrics['reward_format_metrics'].items():\n                    if key not in aggregated_reward_format_metrics:\n                        aggregated_reward_format_metrics[key] = []\n                    aggregated_reward_format_metrics[key].append(value)\n            \n            # 5. Compute average values and add to metric_dict\n            for key, values in aggregated_reward_metrics.items():\n                metric_dict[f'embodied/reward/{key}'] = np.mean(values)\n            \n            for key, values in aggregated_format_metrics.items():\n                metric_dict[f'embodied/format/{key}'] = np.mean(values)\n            \n            for key, values in aggregated_reward_format_metrics.items():\n                metric_dict[f'embodied/reward_format/{key}'] = np.mean(values)\n        \n        return metric_dict\n"
  },
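`_aggregate_validation_results` decides which entries land in the `val-core` section by parsing the `@N` suffix out of metric names such as `mean@8/mean` and keeping only the largest `N` per variable. The sketch below isolates that parsing step; the metric names are hypothetical but follow the shape produced by `aggregate_validation_metrics`.

```python
# Hypothetical per-variable metric names and values.
metric2val = {"mean@1/mean": 0.42, "mean@8/mean": 0.57, "best@8/mean": 0.81, "std@8": 0.11}

n_max_values = []
for name in metric2val:
    if "@" in name and "/mean" in name:
        try:
            n_max_values.append(int(name.split("@")[-1].split("/")[0]))
        except (ValueError, IndexError):
            continue  # malformed names are ignored rather than crashing
n_max = max(n_max_values) if n_max_values else 1
print(n_max)  # -> 8; only mean/maj/best metrics carrying "@8" qualify as val-core
```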
  {
    "path": "siirl/data_coordinator/__init__.py",
    "content": "# Copyright (c) 2025, Shanghai Innovation Institute.  All rights reserved.\n\nfrom .protocol import *\nfrom .data_buffer import *\nfrom .sample import *"
  },
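The `DataCoordinator` defined in the next entry slices one globally balanced batch into per-`dp_rank` views and caches them under a `cache_key`, so TP/PP peer ranks can re-read the same slice while a different key invalidates the cache. The snippet below sketches only that slicing, with plain strings standing in for Ray ObjectRefs and no Ray runtime involved.

```python
# Editorial sketch of the per-rank cache slicing performed in DataCoordinator.get_batch;
# strings stand in for ray.ObjectRef, and the function name build_cache is a local stand-in.
def build_cache(batch_refs, batch_size, balance_partitions):
    return [
        batch_refs[rank * batch_size:(rank + 1) * batch_size]
        for rank in range(balance_partitions)
    ]

refs = [f"ref_{i}" for i in range(8)]
cache = build_cache(refs, batch_size=4, balance_partitions=2)
print(cache[0])  # dp_rank 0 -> ['ref_0', 'ref_1', 'ref_2', 'ref_3']
print(cache[1])  # dp_rank 1 -> ['ref_4', 'ref_5', 'ref_6', 'ref_7']
# Re-reading with the same cache_key returns these same slices; a different
# cache_key invalidates the cache before new samples are drawn from the queue.
```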
  {
    "path": "siirl/data_coordinator/data_buffer.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport asyncio\nfrom typing import Dict, List, Optional, Tuple, Callable, Any\nimport heapq\nimport random\nimport ray\nimport loguru\nimport time\nfrom collections import deque, defaultdict\nfrom ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy\nfrom siirl.data_coordinator.sample import SampleInfo\nfrom siirl.utils.model_utils.seqlen_balancing import calculate_workload, get_seqlen_balanced_partitions\n\n\n@ray.remote\nclass DataCoordinator:\n    \"\"\"\n    A globally unique central Actor responsible for coordinating data producers (RolloutWorkers)\n    and consumers (Trainers). It does not store the actual sample data, only the sample\n    metadata (SampleInfo) and object references (ObjectRef). This allows it to implement\n    complex global sampling strategies at a very low cost.\n    \"\"\"\n    def __init__(self, nnodes: int, ppo_mini_batch_size: int, world_size: int):\n        self.nnodes = nnodes\n        self.ppo_mini_batch_size = ppo_mini_batch_size\n        self.world_size = world_size\n        # Use a deque to store tuples of metadata and references for efficient FIFO operations\n        self._sample_queue: deque[Tuple[SampleInfo, ray.ObjectRef]] = deque()\n        self._put_counter = 0  # Used for round-robin buffer selection\n        self.lock = asyncio.Lock()\n        loguru.logger.info(\"Global DataCoordinator initialized.\")\n        \n        # Cache for multi-rank access: [[rank0_data], [rank1_data], ...]\n        self._cache: List[List[ray.ObjectRef]] = []\n        self._cache_key: Optional[str] = None  # Track which key the cache belongs to\n        \n        # === Statistics tracking for dynamic sampling scenarios ===\n        self._stats_batches_received = 0      # Number of put_batch calls\n        self._stats_samples_received = 0      # Total samples received since last dispatch\n        self._stats_accumulation_start = None # Time when accumulation started\n        self._stats_last_progress_pct = 0     # Last logged progress percentage\n        \n    async def put(self, sample_info: SampleInfo, sample_ref: Any, caller_node_id: Optional[str] = None):\n        \"\"\"\n        Called by a RolloutWorker to register a new sample reference and its metadata.\n        This method automatically routes the ObjectRef to a DataBuffer on its local\n        node to be held.\n        \n        Args:\n            sample_info: Metadata about the sample\n            sample_ref: Ray ObjectRef or the actual sample data\n            caller_node_id: The node ID of the caller. If None, will try to get it from\n                          the runtime context (but this won't work correctly for remote calls)\n        \"\"\"\n        # Due to Ray's small object optimization, an ObjectRef passed by the client\n        # might be automatically resolved to its actual value. 
Here, we ensure that\n        # we are always handling an ObjectRef.\n        if not isinstance(sample_ref, ray.ObjectRef):\n            sample_ref = ray.put(sample_ref)\n\n        # 1. Get the node ID of the caller\n        # Note: When called remotely, ray.get_runtime_context().get_node_id() returns\n        # the node ID of the DataCoordinator actor, not the caller. So we require the\n        # caller to pass their node_id explicitly.\n        if caller_node_id is None:\n            caller_node_id = ray.get_runtime_context().get_node_id()\n            loguru.logger.warning(\n                \"DataCoordinator.put() called without caller_node_id. \"\n                f\"Using DataCoordinator's node_id {caller_node_id[:16]}... which may be incorrect.\"\n            )\n\n        # 2. Inject the node ID into SampleInfo for subsequent filtering\n        #    Only inject if node_id has not been manually set, to facilitate testing.\n        if sample_info.node_id is None:\n            sample_info.node_id = caller_node_id\n\n        # 3. Register the metadata and reference to the global queue + update statistics\n        async with self.lock:\n            self._sample_queue.append((sample_info, sample_ref))\n            # Update statistics (single sample put)\n            self._stats_batches_received += 1\n            self._stats_samples_received += 1\n            if self._stats_accumulation_start is None:\n                self._stats_accumulation_start = time.time()\n\n    async def put_batch(self, sample_infos: List[SampleInfo], sample_refs: List[ray.ObjectRef], caller_node_id: Optional[str] = None):\n        \"\"\"\n        Called by a worker to register a batch of new sample references and their metadata.\n        This method routes the ObjectRefs to DataBuffers on their local nodes.\n        \n        Args:\n            sample_infos: List of metadata for each sample\n            sample_refs: List of Ray ObjectRefs\n            caller_node_id: The node ID of the caller. If None, will try to get it from\n                          the runtime context (but this won't work correctly for remote calls)\n        \"\"\"\n        if not sample_refs:\n            return\n\n        # Get the node ID of the caller\n        # Note: When called remotely, ray.get_runtime_context().get_node_id() returns\n        # the node ID of the DataCoordinator actor, not the caller. So we require the\n        # caller to pass their node_id explicitly.\n        if caller_node_id is None:\n            caller_node_id = ray.get_runtime_context().get_node_id()\n            loguru.logger.warning(\n                \"DataCoordinator.put_batch() called without caller_node_id. \"\n                f\"Using DataCoordinator's node_id {caller_node_id[:16]}... 
which may be incorrect.\"\n            )\n\n        for i in range(len(sample_infos)):\n            if sample_infos[i].node_id is None:\n                sample_infos[i].node_id = caller_node_id\n        \n        async with self.lock:\n            self._sample_queue.extend(zip(sample_infos, sample_refs))\n            \n            # Update statistics\n            self._stats_batches_received += 1\n            self._stats_samples_received += len(sample_refs)\n            if self._stats_accumulation_start is None:\n                self._stats_accumulation_start = time.time()\n\n    async def get_batch(\n        self, \n        batch_size: int, \n        dp_rank: int, \n        filter_plugin: Optional[Callable[[SampleInfo], bool]] = None,\n        balance_partitions: Optional[int] = None,\n        cache_key: Optional[str] = None\n    ) -> List[ray.ObjectRef]:\n        \"\"\"Called by a Trainer to get a batch of sample ObjectRefs.\n        \n        Supports caching for multi-rank access within the same node, filter plugins\n        for custom sampling logic, and length balancing across partitions.\n        \n        Args:\n            batch_size: The requested batch size per partition.\n            dp_rank: The data parallel rank requesting data.\n            filter_plugin: Optional filter function for custom sampling logic.\n            balance_partitions: Number of partitions for length balancing (typically dp_size).\n            cache_key: Key to identify the cache (e.g., node name like 'compute_reward').\n                       Different keys invalidate the cache to prevent data mixing.\n        \n        Returns:\n            A list of sample ObjectRefs for the specified dp_rank.\n        \"\"\"\n        async with self.lock:\n            global_batch_size = batch_size * balance_partitions\n            \n            # === Phase 1: Check cache validity ===\n            if self._cache:\n                if cache_key is not None and self._cache_key == cache_key:\n                    # Same key: return cached data (allows multiple reads by tp/pp ranks)\n                    if dp_rank < len(self._cache):\n                        return self._cache[dp_rank]\n                    return []\n                # Different key: invalidate old cache\n                self._cache = []\n                self._cache_key = None\n            \n            # === Phase 2: Fetch data from queue ===\n            if filter_plugin:\n                # With filter: O(N) scan\n                if isinstance(filter_plugin, list):\n                    batch_items = [item for item in self._sample_queue \n                                   if all(f(item[0]) for f in filter_plugin)]\n                else:\n                    batch_items = [item for item in self._sample_queue \n                                   if filter_plugin(item[0])]\n                \n                if len(batch_items) < global_batch_size:\n                    self._log_accumulation_progress(len(batch_items), global_batch_size)\n                    return []\n                \n                # Take only what we need and remove from queue\n                batch_items = batch_items[:global_batch_size]\n                refs_to_remove = {item[1] for item in batch_items}\n                self._sample_queue = deque(\n                    item for item in self._sample_queue if item[1] not in refs_to_remove\n                )\n            else:\n                # No filter: efficient FIFO\n                if len(self._sample_queue) < global_batch_size:\n     
               self._log_accumulation_progress(len(self._sample_queue), global_batch_size)\n                    return []\n                \n                batch_items = [self._sample_queue.popleft() for _ in range(global_batch_size)]\n            \n            # === Phase 3: Apply length balancing ===\n            if balance_partitions and balance_partitions > 1:\n                batch_refs = self._apply_length_balancing(batch_items, balance_partitions)\n            else:\n                batch_refs = [item[1] for item in batch_items]\n            \n            # === Phase 4: Build cache (unified nested list structure) ===\n            self._cache = [\n                batch_refs[rank * batch_size:(rank + 1) * batch_size]\n                for rank in range(balance_partitions)\n            ]\n            self._cache_key = cache_key\n            \n            self._log_dispatch_stats(global_batch_size)\n            return self._cache[dp_rank]\n\n    def _log_accumulation_progress(self, current_samples: int, target_samples: int):\n        \"\"\"Log progress milestones at INFO level when reaching 25%, 50%, 75%.\"\"\"\n        if target_samples <= 0:\n            return\n            \n        current_pct = int(current_samples * 100 / target_samples)\n        \n        # Log at 25%, 50%, 75% milestones (only once per milestone)\n        # Log the highest crossed milestone that hasn't been logged yet\n        milestones = [25, 50, 75]\n        highest_crossed = None\n        for milestone in milestones:\n            if current_pct >= milestone and self._stats_last_progress_pct < milestone:\n                highest_crossed = milestone\n        \n        if highest_crossed is not None:\n            wait_time = time.time() - self._stats_accumulation_start if self._stats_accumulation_start else 0\n            loguru.logger.info(\n                f\"[DataCoordinator] Accumulation {highest_crossed}%: {current_samples}/{target_samples} samples \"\n                f\"({self._stats_batches_received} batches, {wait_time:.1f}s)\"\n            )\n            self._stats_last_progress_pct = highest_crossed\n    \n    def _log_dispatch_stats(self, dispatched_samples: int):\n        \"\"\"Log statistics when dispatching a batch and reset counters.\"\"\"\n        wait_time = time.time() - self._stats_accumulation_start if self._stats_accumulation_start else 0\n\n        avg_samples_per_batch = (\n            self._stats_samples_received / self._stats_batches_received\n            if self._stats_batches_received > 0 else 0\n        )\n\n        total_received = self._stats_samples_received\n        remaining_in_queue = total_received - dispatched_samples\n\n        loguru.logger.info(\n            f\"[DataCoordinator DISPATCH] \"\n            f\"Accumulated: {total_received} samples from {self._stats_batches_received} batches | \"\n            f\"Dispatching: {dispatched_samples} samples | \"\n            f\"Remaining in queue: {remaining_in_queue} | \"\n            f\"Avg per batch: {avg_samples_per_batch:.1f} | \"\n            f\"Wait: {wait_time:.1f}s\"\n        )\n\n        self._stats_batches_received = 0\n        self._stats_samples_received = 0\n        self._stats_accumulation_start = None\n        self._stats_last_progress_pct = 0\n    \n    def _apply_length_balancing(\n        self, \n        batch_items: List[Tuple[SampleInfo, ray.ObjectRef]], \n        k_partitions: int,\n        keep_mini_batch = False\n    ) -> List[ray.ObjectRef]:\n        \"\"\"Applies the length balancing algorithm to reorder 
samples.\n        \n        Uses the LPT (Longest Processing Time) algorithm to reorder samples so that\n        if they are evenly distributed among k_partitions workers, the sum of\n        sample lengths for each worker is as balanced as possible.\n        \n        Supports Group N: samples with the same uid will be assigned to the same partition,\n        ensuring correct group-relative advantage computation for GRPO and similar algorithms.\n        \n        Args:\n            batch_items: A list of (SampleInfo, ObjectRef) tuples.\n            k_partitions: The number of partitions (typically the DP size).\n            keep_mini_batch: Whether to keep mini-batch structure during balancing.\n            \n        Returns:\n            A reordered list of ObjectRefs.\n        \"\"\"\n        # ========== Step 1: Group samples by uid ==========\n        uid_to_indices = defaultdict(list)\n        for idx, (sample_info, _) in enumerate(batch_items):\n            uid = sample_info.uid if sample_info.uid is not None else str(idx)\n            uid_to_indices[uid].append(idx)\n\n        # Check if grouping is needed (max_group_size > 1 means we have Group N)\n        max_group_size = max(len(indices) for indices in uid_to_indices.values()) if uid_to_indices else 1\n\n        if max_group_size == 1:\n            # No grouping needed, use original single-sample balancing logic\n            return self._apply_length_balancing_single_sample(batch_items, k_partitions, keep_mini_batch)\n\n        # ========== Step 2: Calculate workload for each Group ==========\n        group_list = list(uid_to_indices.keys())  # All unique uids\n        group_workloads = []\n        for uid in group_list:\n            indices = uid_to_indices[uid]\n            # Group workload = sum of all samples' sum_tokens in the group\n            total_tokens = sum(batch_items[i][0].sum_tokens for i in indices)\n            group_workloads.append(total_tokens)\n\n        # ========== Step 3: Balance Groups across partitions ==========\n        workload_lst = calculate_workload(group_workloads)\n\n        # Check if number of groups is divisible by k_partitions\n        num_groups = len(group_list)\n        if num_groups < k_partitions:\n            loguru.logger.warning(\n                f\"Number of groups ({num_groups}) is less than partitions ({k_partitions}). \"\n                f\"Some partitions will be empty. Falling back to single-sample balancing.\"\n            )\n            return self._apply_length_balancing_single_sample(batch_items, k_partitions, keep_mini_batch)\n\n        equal_size = num_groups % k_partitions == 0\n        if not equal_size:\n            loguru.logger.warning(\n                f\"Number of groups ({num_groups}) is not divisible by partitions ({k_partitions}). 
\"\n                f\"Some partitions may have uneven group counts.\"\n            )\n\n        # Partition groups across workers\n        group_partitions = get_seqlen_balanced_partitions(workload_lst, k_partitions=k_partitions, equal_size=equal_size)\n\n        # ========== Step 4: Expand groups to samples, keeping group integrity ==========\n        reordered_refs = []\n        for partition_group_indices in group_partitions:\n            for group_idx in partition_group_indices:\n                uid = group_list[group_idx]\n                sample_indices = uid_to_indices[uid]\n                # Add all samples of the same group together, preserving original order within group\n                for sample_idx in sample_indices:\n                    reordered_refs.append(batch_items[sample_idx][1])\n\n        loguru.logger.debug(\n            f\"Applied GROUP-aware length balancing: \"\n            f\"{len(batch_items)} samples in {num_groups} groups (group_size={max_group_size}) \"\n            f\"reordered into {k_partitions} partitions\"\n        )\n\n        return reordered_refs\n\n    def _apply_length_balancing_single_sample(\n        self,\n        batch_items: List[Tuple[SampleInfo, ray.ObjectRef]],\n        k_partitions: int,\n        keep_mini_batch=False,\n    ) -> List[ray.ObjectRef]:\n        \"\"\"Original length balancing logic for single samples (no UID grouping).\n        \n        This is used when there's no Group N (each uid has only one sample).\n        \n        Args:\n            batch_items: A list of (SampleInfo, ObjectRef) tuples.\n            k_partitions: The number of partitions (typically the DP size).\n            keep_mini_batch: Whether to keep mini-batch structure during balancing.\n            \n        Returns:\n            A reordered list of ObjectRefs.\n        \"\"\"\n        # Extract the length of each sample.\n        # Use sum_tokens as the length metric (includes prompt + response).\n        seqlen_list = [item[0].sum_tokens for item in batch_items]\n        \n        # Use the karmarkar_karp balance\n        workload_lst = calculate_workload(seqlen_list)\n        # Decouple the DP balancing and mini-batching.\n        if keep_mini_batch:\n            minibatch_size = self.ppo_mini_batch_size\n            minibatch_num = len(workload_lst) // minibatch_size\n            global_partition_lst = [[] for _ in range(self.world_size)]\n            for i in range(minibatch_num):\n                rearrange_minibatch_lst = get_seqlen_balanced_partitions(\n                    workload_lst[i * minibatch_size : (i + 1) * minibatch_size],\n                    k_partitions=self.world_size,\n                    equal_size=True,\n                )\n                for j, part in enumerate(rearrange_minibatch_lst):\n                    global_partition_lst[j].extend([x + minibatch_size * i for x in part])\n        else:\n            global_partition_lst = get_seqlen_balanced_partitions(\n                workload_lst, k_partitions=self.world_size, equal_size=True\n            )    \n            \n            \n        # Place smaller micro-batches at both ends to reduce the bubbles in pipeline parallel.\n        for idx, partition in enumerate(global_partition_lst):\n            partition.sort(key=lambda x: (workload_lst[x], x))\n            ordered_partition = partition[::2] + partition[1::2][::-1]\n            global_partition_lst[idx] = ordered_partition\n        \n        # Reorder the samples based on the partitioning result.\n        # Concatenate the 
partitions in order: [all samples from partition_0, all from partition_1, ...]\n        reordered_refs = []\n        for partition in global_partition_lst:\n            for original_idx in partition:\n                reordered_refs.append(batch_items[original_idx][1])\n        \n        loguru.logger.debug(\n            f\"Applied length balancing: {len(batch_items)} samples reordered into {k_partitions} partitions\"\n        )\n        \n        return reordered_refs\n        \n\n    async def get_all_by_filter(self, filter_plugin: Callable[[SampleInfo], bool]) -> List[ray.ObjectRef]:\n        \"\"\"\n        Gets ALL sample ObjectRefs that match the filter plugin, consuming them from the queue.\n        This is useful for pipeline-based data passing where a downstream stage needs the\n        entire output of an upstream stage.\n        \"\"\"\n        async with self.lock:\n            # 1. Find all items that match the filter.\n            items_to_return = [item for item in self._sample_queue if filter_plugin(item[0])]\n            \n            if not items_to_return:\n                return []\n\n            # 2. Extract their ObjectRefs.\n            batch_refs = [item[1] for item in items_to_return]\n\n            # 3. Efficiently remove the selected items from the original queue.\n            refs_to_remove = {ref for ref in batch_refs}\n            self._sample_queue = deque(item for item in self._sample_queue if item[1] not in refs_to_remove)\n            \n            return batch_refs\n\n    async def get_valid_size(self) -> int:\n        \"\"\"Returns the number of samples in the current queue.\"\"\"\n        async with self.lock:\n            return len(self._sample_queue)\n    \n    async def peek_source_dp_size(self, filter_plugin: Callable[[SampleInfo], bool]) -> Optional[int]:\n        \"\"\"\n        Peek at the source_dp_size of matching samples without consuming them.\n        \n        Args:\n            filter_plugin: Filter function to find matching samples\n            \n        Returns:\n            The source_dp_size if found, None otherwise\n        \"\"\"\n        async with self.lock:\n            for sample_info, _ in self._sample_queue:\n                if filter_plugin(sample_info):\n                    source_dp_size = sample_info.dict_info.get('source_dp_size')\n                    if source_dp_size is not None:\n                        return source_dp_size\n            return None\n\n    def reset_cache(self):\n        \"\"\"Reset the coordinator state for a new training step.\"\"\"\n        loguru.logger.info(\"Resetting DataCoordinator cache\")\n        self._sample_queue.clear()\n        self._cache = []\n        self._cache_key = None\n\n    def __repr__(self) -> str:\n        return f\"<DataCoordinator(total_samples={len(self._sample_queue)})>\"\n\n\n# ====================================================================\n# Initialization Logic\n# ====================================================================\n\ndef init_data_coordinator(num_buffers: int, ppo_mini_batch_size: int, world_size: int) -> ray.actor.ActorHandle:\n    \"\"\"\n    Initializes the data coordination system, which includes a global DataCoordinator\n    and multiple distributed DataBuffers. 
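The coordinator is created as a named, detached Ray actor, so repeated calls connect to the same existing instance instead of spawning a new one. 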
Returns a single, unified DataCoordinator\n    handle to the user.\n\n    Args:\n        num_buffers: The number of distributed DataBuffer instances to create,\n                     usually equal to the number of nodes or total GPUs.\n        ppo_mini_batch_size: The PPO mini-batch size, used when balancing samples\n                             while preserving mini-batch structure.\n        world_size: The total number of workers across which samples are\n                    partitioned during length balancing.\n\n    Returns:\n        The Actor handle for the DataCoordinator.\n    \"\"\"\n    if not ray.is_initialized():\n        raise RuntimeError(\"Ray must be initialized before calling init_data_coordinator.\")\n\n    # 1. Create or get the globally unique DataCoordinator\n    # Use a global name to ensure the coordinator's uniqueness\n    coordinator_name = \"global_data_coordinator\"\n    try:\n        coordinator = ray.get_actor(coordinator_name)\n        loguru.logger.info(f\"Connected to existing DataCoordinator actor '{coordinator_name}'.\")\n    except ValueError:\n        loguru.logger.info(f\"Creating new DataCoordinator actor with global name '{coordinator_name}'.\")\n        coordinator = DataCoordinator.options(name=coordinator_name, lifetime=\"detached\").remote(nnodes=num_buffers, ppo_mini_batch_size=ppo_mini_batch_size, world_size=world_size)\n\n    return coordinator\n"
  },
  {
    "path": "siirl/data_coordinator/dataloader/__init__.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .data_loader_node import DataLoaderNode\n\n__all__ = [\"DataLoaderNode\"]\n"
  },
  {
    "path": "siirl/data_coordinator/dataloader/data_loader_node.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Any, Dict, Iterator, Optional\n\nimport torch\nfrom loguru import logger\nfrom torch.utils.data import RandomSampler, SequentialSampler\nfrom torchdata.stateful_dataloader import StatefulDataLoader\n\nfrom siirl.execution.dag import Node, NodeRole, NodeStatus, NodeType\nfrom siirl.models.loader import load_tokenizer\nfrom siirl.params import SiiRLArguments\n\nfrom siirl.data_coordinator.dataloader.partitioned_dataset import PartitionedRLHFDataset\n\n\nclass RepeatDataset(torch.utils.data.Dataset):\n    \"\"\"\n    A dataset wrapper that repeats the base dataset multiple times.\n\n    This class allows you to virtually extend the length of a given dataset by repeating its samples\n    a specified number of times. It is useful for scenarios where you want to train for more epochs\n    without reloading or reshuffling the data, or to balance datasets by oversampling.\n\n    Args:\n        base_dataset (torch.utils.data.Dataset): The original dataset to be repeated.\n        repeat_factor (int): The number of times to repeat the base dataset.\n\n    Attributes:\n        base_dataset (torch.utils.data.Dataset): The original dataset.\n        repeat_factor (int): The number of repetitions.\n        length (int): The total length of the repeated dataset.\n\n    Example:\n        >>> base_dataset = MyCustomDataset()\n        >>> repeated_dataset = RepeatDataset(base_dataset, repeat_factor=3)\n        >>> len(repeated_dataset) == 3 * len(base_dataset)\n        True\n\n    \"\"\"\n\n    def __init__(self, base_dataset, repeat_factor):\n        self.base_dataset = base_dataset\n        self.repeat_factor = repeat_factor\n        self.length = len(base_dataset) * repeat_factor\n\n    def __len__(self):\n        return self.length\n\n    def __getitem__(self, idx):\n        return self.base_dataset[idx % len(self.base_dataset)]\n\n\nclass DataLoaderNode(Node):\n    \"\"\"\n    Represents a data loader node in the DAG.\n    This version uses the PartitionedRLHFDataset for efficient, memory-safe\n    distributed data loading. Each rank only loads and processes its own data slice.\n    \"\"\"\n\n    def __init__(\n        self, node_id: str, global_config: SiiRLArguments, config: Optional[Dict[str, Any]] = None, retry_limit: int = 0\n    ):\n        \"\"\"\n        Initialize a data loader node.\n\n        Args:\n            node_id (str): The unique identifier of the node.\n            global_config(SiiRLArguments): The arguments from config file.\n            config (Optional[Dict[str, Any]]): Specific configuration information for the node. Defaults to an empty dictionary.\n            retry_limit (int): The maximum number of retries when the node execution fails. 
Defaults to 0 (no retries).\n        \"\"\"\n        super().__init__(node_id, NodeType.DATA_LOAD, NodeRole.DEFAULT, config=config, retry_limit=retry_limit)\n        self.global_config = global_config\n\n        if \"tokenizer\" in self.config:\n            self.tokenizer = self.config[\"tokenizer\"]\n            self.processor = self.config[\"processor\"]\n        else:\n            # Load tokenizer and processor\n            tokenizer_module = load_tokenizer(model_args=global_config.actor_rollout_ref.model)\n            self.tokenizer = tokenizer_module[\"tokenizer\"]\n            self.processor = tokenizer_module[\"processor\"]\n\n        # force load in main process for vision language model\n        self.num_loader_workers = (\n            0\n            if global_config.actor_rollout_ref.rollout.name == \"sglang\" or self.processor is not None\n            else config.get(\"num_loader_workers\", 8)\n        )\n\n        # Get group world size, rank, parallel size from config.\n        #   Group world size means the rollout pytorch distributed group total gpus.\n        #   Group rank means the process index in distributed group.\n        #   Group parallel size means the rollout total parallel size, e.g. tp_size * pp_size\n        self.group_world_size = config[\"group_world_size\"]\n        self.group_rank = config[\"group_rank\"]\n        self.group_parallel_size = config[\"group_parallel_size\"]\n        if self.group_world_size % self.group_parallel_size != 0:\n            # Log an error or raise an exception if world_size is not divisible by group_parallel_size\n            error_msg = f\"group_world_size ({self.group_world_size}) must be divisible by group_parallel_size ({self.group_parallel_size}).\"\n            logger.error(error_msg)\n            raise ValueError(error_msg)\n        # Calculate the world size and rank for rollout data parallelism, which is actually needed for data partitioning.\n        self.rollout_ddp_world_size = self.group_world_size // self.group_parallel_size\n        self.rollout_ddp_rank = self.group_rank // self.group_parallel_size\n\n        self._current_train_iter: Optional[Iterator] = None\n        self._current_val_iter: Optional[Iterator] = None\n        self._current_epoch: int = -1\n\n        self._create_dataloader()\n\n        self.num_train_batches = len(self.train_dataloader) if self.train_dataloader else 0\n        self.num_val_batches = len(self.val_dataloader) if self.val_dataloader else 0\n\n        logger.info(f\"DataLoaderNode '{self.node_id}' initialized:\")\n        logger.info(f\"  Group rank: {self.group_rank} / {self.group_world_size}\")\n        logger.info(f\"  Rollout DDP rank: {self.rollout_ddp_rank} / {self.rollout_ddp_world_size}\")\n        logger.info(f\"  Train batches per epoch for this rank: {self.num_train_batches}\")\n        logger.info(f\"  Total training steps (approx): {self.total_training_steps}\")\n\n    def _create_dataloader(self):\n        \"\"\"\n        Initializes and configures the training and validation DataLoaders for RLHF tasks.\n\n        When enable `auto_repeat`, if the dataset is too small to form a batch, it will be automatically repeated\n        until at least one batch can be formed.\n\n        This method performs the following steps:\n        1. Creates the training dataset (`PartitionedRLHFDataset`) with the provided configuration, tokenizer, processor, and distributed data parallel (DDP) settings.\n        2. 
Sets up the sampler for the training DataLoader:\n            - Uses a `RandomSampler` with a seeded generator if shuffling is enabled in the configuration.\n            - Uses a `SequentialSampler` otherwise.\n        3. Configures the tokenizer's padding side to \"left\" to ensure correct sequence alignment.\n        4. Creates the training DataLoader (`StatefulDataLoader`) with the specified batch size, number of workers, sampler, and collator.\n        5. Creates the validation dataset and DataLoader, using the full dataset as a single batch for evaluation.\n        6. Asserts that the training DataLoader contains at least one batch.\n        7. Calculates the total number of training steps based on the number of batches and epochs, or uses a user-specified value if provided.\n        8. Updates the total training steps in the optimizer configurations for both the actor and critic components.\n        \"\"\"\n        # Create the partitioned training dataset for this rank\n        self.train_dataset = PartitionedRLHFDataset(\n            config=self.global_config,\n            tokenizer=self.tokenizer,\n            processor=self.processor,\n            ddp_rank=self.rollout_ddp_rank,\n            ddp_world_size=self.rollout_ddp_world_size,\n            is_eval=False,\n            drop_last=self.config.get(\"train_drop_last\", True),\n        )\n\n        # Calculate batch size per rank\n        train_batch_size = self.global_config.data.train_batch_size // self.rollout_ddp_world_size\n        self.train_batch_size = train_batch_size\n        # Auto-repeat logic: if dataset is too small, repeat it until at least one batch can be formed\n        auto_repeat = self.config.get(\"auto_repeat\", False)\n        train_len = len(self.train_dataset)\n        if auto_repeat and train_len < train_batch_size:\n            repeat_factor = (train_batch_size + train_len - 1) // train_len\n\n            self.train_dataset = RepeatDataset(self.train_dataset, repeat_factor)\n            logger.warning(\n                f\"Rank {self.rollout_ddp_rank}: Training dataset too small (size={train_len}), \"\n                f\"auto-repeating {repeat_factor} times to ensure at least one batch (batch_size={train_batch_size}). 
\"\n                f\"Now RepeatDataset size={len(self.train_dataset)}\"\n            )\n\n        # Choose sampler: RandomSampler with seed if shuffle enabled, else SequentialSampler\n        if self.global_config.data.shuffle:\n            train_dataloader_generator = torch.Generator()\n            train_dataloader_generator.manual_seed(self.global_config.trainer.seed)\n            sampler = RandomSampler(data_source=self.train_dataset, generator=train_dataloader_generator)\n        else:\n            sampler = SequentialSampler(data_source=self.train_dataset)\n\n        # Create the training dataloader with the specified batch size, workers, sampler, and collator\n        from siirl.data_coordinator.dataloader.partitioned_dataset import collate_fn as default_collate_fn\n\n        self.train_dataloader = StatefulDataLoader(\n            dataset=self.train_dataset,\n            batch_size=train_batch_size,\n            num_workers=self.num_loader_workers,\n            drop_last=True,\n            collate_fn=default_collate_fn,\n            sampler=sampler,\n        )\n\n        # Create the partitioned validation dataset for this rank\n        self.val_dataset = PartitionedRLHFDataset(\n            config=self.global_config,\n            tokenizer=self.tokenizer,\n            processor=self.processor,\n            ddp_rank=self.rollout_ddp_rank,\n            ddp_world_size=self.rollout_ddp_world_size,\n            is_eval=True,\n            drop_last=self.config.get(\"eval_drop_last\", False),\n        )\n\n        # Create the validation dataloader, loading the entire validation set as one batch\n        val_batch_size = self.global_config.data.val_batch_size  # Prefer config value if set\n        if val_batch_size is None:\n            val_batch_size = len(self.val_dataset)\n        self.val_dataloader = StatefulDataLoader(\n            dataset=self.val_dataset,\n            batch_size=val_batch_size,\n            num_workers=self.num_loader_workers,\n            shuffle=False,\n            drop_last=False,\n            collate_fn=default_collate_fn,\n        )\n\n        # Assert that there is at least one batch for this rank\n        assert (\n            len(self.train_dataloader) >= 1\n        ), f\"Not enough data for current rank (rank id: {self.rollout_ddp_rank}) to consume. 
Please increase the train datasets or reduce the number of GPUs.\"\n        assert len(self.val_dataloader) >= 1, \"Validation dataloader is empty!\"\n        # Calculate the number of batches and total training steps\n        num_batches = len(self.train_dataloader) if self.train_dataloader else 0\n        total_training_steps = num_batches * self.global_config.trainer.total_epochs\n        # Use user-specified total_training_steps if provided\n        if self.global_config.trainer.total_training_steps is not None:\n            total_training_steps = self.global_config.trainer.total_training_steps\n\n        self.total_training_steps = total_training_steps\n\n        # Update total training steps in optimizer configs for actor and critic\n        self.global_config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps\n        self.global_config.critic.optim.total_training_steps = total_training_steps\n\n        # Indicates the samples for this rank has already been expand\n        self.is_val_trailing_rank = self.val_dataset.is_trailing_rank\n\n    def _reinit_dataloader_sampler(self):\n        \"\"\"\n        Re-initializes the sampler and dataloader to clear any internal state (like being exhausted).\n        This is useful when resuming from a checkpoint that was saved at the end of an epoch.\n        \"\"\"\n        # Re-create the sampler\n        if self.global_config.data.shuffle:\n            train_dataloader_generator = torch.Generator()\n            train_dataloader_generator.manual_seed(self.global_config.trainer.seed)\n            sampler = RandomSampler(data_source=self.train_dataset, generator=train_dataloader_generator)\n        else:\n            sampler = SequentialSampler(data_source=self.train_dataset)\n\n        # Re-create the dataloader with the new sampler\n        from siirl.data_coordinator.dataloader.partitioned_dataset import collate_fn as default_collate_fn\n\n        train_batch_size = self.global_config.data.train_batch_size // self.rollout_ddp_world_size\n\n        self.train_dataloader = StatefulDataLoader(\n            dataset=self.train_dataset,\n            batch_size=train_batch_size,\n            num_workers=self.num_loader_workers,\n            drop_last=True,\n            collate_fn=default_collate_fn,\n            sampler=sampler,\n        )\n        logger.info(f\"Node {self.node_id}: Re-initialized dataloader and sampler.\")\n\n    def get_train_dataloader(self):\n        \"\"\"\n        Returns the training data loader.\n\n        Returns:\n            DataLoader: The data loader used for training.\n        \"\"\"\n        return self.train_dataloader\n\n    def get_val_dataloader(self):\n        \"\"\"\n        Returns the validation dataloader.\n\n        Returns:\n            DataLoader: The dataloader used for validation data.\n        \"\"\"\n        return self.val_dataloader\n\n    def run(self, epoch: Optional[int] = None, is_validation_step: bool = False, **kwargs: Any) -> Any:\n        \"\"\"\n        Executes the data loading process for a given step or validation.\n\n        Args:\n            epoch (Optional[int]): The current epoch number. 
Required for training steps to handle\n                                   sampler state (e.g., DistributedSampler.set_epoch()).\n            is_validation_step (bool): Flag indicating if validation data is requested.\n            **kwargs: Additional arguments (not used directly in this basic version but\n                      part of the Node.execute signature).\n\n        Returns:\n            Any: A batch of data. The structure depends on the collate_fn.\n\n        Raises:\n            ValueError: If epoch is not provided for a training step.\n            StopIteration: If the dataloader is exhausted and cannot provide more data\n                           (though this might be handled by the DAG scheduler).\n        \"\"\"\n        self.update_status(NodeStatus.RUNNING)\n        logger.debug(f\"Node {self.node_id} execute: epoch={epoch}, is_validation_step={is_validation_step}\")\n\n        try:\n            if is_validation_step:\n                if not self.val_dataloader:  # Handles empty validation dataset\n                    logger.warning(f\"Rank {self.group_rank}: Validation dataloader is not available or empty.\")\n                    self.update_status(NodeStatus.COMPLETED)  # Or FAILED if this is an error condition\n                    return None  # Or an empty batch marker\n\n                # Validation dataloader loads the entire validation set as one batch.\n                # We get a fresh iterator each time for validation.\n                if self._current_val_iter is None:\n                    self._current_val_iter = iter(self.val_dataloader)\n\n                try:\n                    batch = next(self._current_val_iter)\n                    logger.debug(f\"Node {self.node_id}: Yielding validation batch.\")\n                    # Reset for next validation call, as it's one batch\n                    self._current_val_iter = None\n                except StopIteration:\n                    logger.warning(\n                        f\"Node {self.node_id}: Validation dataloader exhausted unexpectedly (should be one batch). 
Resetting.\"\n                    )\n                    # This case should ideally not happen if batch_size = len(dataset) and it's not empty\n                    self._current_val_iter = iter(self.val_dataloader)  # Get a fresh iterator\n                    try:\n                        batch = next(self._current_val_iter)\n                    except StopIteration:\n                        logger.error(f\"Node {self.node_id}: Validation dataloader is empty even after reset.\")\n                        self.update_status(NodeStatus.FAILED, \"Validation dataloader empty\")\n                        return None\n            else:  # Training step\n                if epoch is None:\n                    error_msg = \"Epoch must be provided for training steps.\"\n                    logger.error(f\"Node {self.node_id}: {error_msg}\")\n                    self.update_status(NodeStatus.FAILED, error_msg)\n                    raise ValueError(error_msg)\n\n                if not self.train_dataloader:  # Handles empty training dataset\n                    logger.warning(f\"Rank {self.group_rank}: Training dataloader is not available or empty.\")\n                    self.update_status(NodeStatus.COMPLETED)  # Or FAILED\n                    return None  # Or an empty batch marker\n\n                # Flag to track if we just created the iterator\n                iterator_just_created = False\n                if self._current_epoch != epoch or self._current_train_iter is None:\n                    logger.info(f\"Node {self.node_id}: New epoch ({epoch}) or first step. Initializing train iterator.\")\n                    self._current_epoch = epoch\n                    # Set epoch for DistributedSampler if applicable\n                    if hasattr(self.train_dataloader.sampler, \"set_epoch\") and isinstance(\n                        self.train_dataloader.sampler, DistributedSampler\n                    ):\n                        logger.debug(f\"Node {self.node_id}: Setting epoch {epoch} for DistributedSampler.\")\n                        self.train_dataloader.sampler.set_epoch(epoch)\n\n                    self._current_train_iter = iter(self.train_dataloader)\n                    iterator_just_created = True\n\n                try:\n                    batch = next(self._current_train_iter)\n                    logger.debug(f\"Node {self.node_id}: Yielding training batch for epoch {epoch}.\")\n                except StopIteration:\n                    # FIX: Handle resume from end-of-epoch state\n                    if iterator_just_created:\n                        logger.warning(\n                            f\"Node {self.node_id}: Iterator exhausted immediately after creation. \"\n                            f\"This indicates resumption from a completed epoch state. \"\n                            f\"Resetting dataloader to start fresh for epoch {epoch}.\"\n                        )\n                        # Re-create the dataloader to clear the internal \"exhausted\" state\n                        # We keep the dataset to avoid reloading heavy data\n                        self._reinit_dataloader_sampler()\n                        self._current_train_iter = iter(self.train_dataloader)\n                        batch = next(self._current_train_iter)\n                    else:\n                        # Real end of epoch\n                        error_msg = (\n                            f\"Training dataloader exhausted for epoch {epoch}. 
This might be expected at epoch end.\"\n                        )\n                        logger.info(f\"Node {self.node_id}: {error_msg}\")\n                        # We might not want to mark FAILED here, as it's a natural end of an iterator.\n                        # The caller (DAG executor) should decide if more data was expected.\n                        # For now, let's re-raise StopIteration to signal the caller.\n                        self.update_status(NodeStatus.COMPLETED)  # Or a custom status like 'EPOCH_END'\n                        raise  # Re-raise StopIteration\n\n            self.update_status(NodeStatus.COMPLETED)\n            return batch\n\n        except Exception as e:\n            error_msg = f\"Error during data loading in node {self.node_id}: {e}\"\n            logger.exception(error_msg)  # Log with stack trace\n            self.update_status(NodeStatus.FAILED, str(e))\n            raise  # Re-raise the exception so the DAG executor can handle it\n\n    def state_dict(self) -> Dict[str, Any]:\n        \"\"\"\n        Captures the state of the DataLoaderNode, primarily the training dataloader's state.\n\n        Returns:\n            Dict[str, Any]: A dictionary containing the node's state.\n        \"\"\"\n        return {\n            \"train_dataloader_state\": self.train_dataloader.state_dict(),\n        }\n\n    def load_state_dict(self, state_dict: Dict[str, Any]):\n        \"\"\"\n        Restores the state of the DataLoaderNode from a state dictionary.\n\n        Args:\n            state_dict (Dict[str, Any]): The state dictionary to load.\n        \"\"\"\n        if \"train_dataloader_state\" in state_dict:\n            self.train_dataloader.load_state_dict(state_dict[\"train_dataloader_state\"])\n            # After loading state, the current iterator is invalid because it's tied to the old\n            # sampler state. Setting it to None forces the run() method to create a new,\n            # valid iterator that is synchronized with the restored state.\n            self._current_train_iter = None\n            logger.info(\n                f\"Node {self.node_id} (Rank {self.group_rank}): Successfully loaded train_dataloader state. Iterator will be reset on next call.\"\n            )\n"
  },
  {
    "path": "siirl/data_coordinator/dataloader/embodied_preprocess.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom pathlib import Path\nfrom typing import Tuple\n\nimport pandas as pd\nfrom libero.libero import benchmark\nfrom loguru import logger\n\n\ndef prepare_libero_train_valid_datasets(\n    task_suite_name: str,\n    num_trials_per_task: int,\n    dataset_dir: str,\n) -> Tuple[Path, Path]:\n    \"\"\"\n    Generates identical training and validation dataset manifests for a LIBERO task suite.\n    This manifest contains the necessary metadata for the VLA agent to initialize environments.\n\n    The function queries the actual number of initial states for each task and generates\n    trial_ids up to min(actual_initial_states, num_trials_per_task) to ensure all\n    generated (task_id, trial_id) pairs are valid.\n\n    Args:\n        task_suite_name (str): The name of the task suite.\n        num_trials_per_task (int): The maximum number of trials to include for each task.\n                                   Actual trials may be less if a task has fewer initial states.\n        dataset_dir (str): The directory where the manifest files will be saved.\n\n    Examples:\n        - Task with 50 initial states, num_trials_per_task=100 → generates trial_ids 0-49\n        - Task with 150 initial states, num_trials_per_task=100 → generates trial_ids 0-99\n        - Task with 150 initial states, num_trials_per_task=200 → generates trial_ids 0-149\n    \"\"\"\n    # 1. --- Validate parameters and prepare output directory ---\n    if num_trials_per_task < 1:\n        raise ValueError(\"Number of trials per task must be at least 1\")\n\n    output_path = Path(dataset_dir)\n    output_path.mkdir(parents=True, exist_ok=True)\n\n    train_file = output_path / \"train.parquet\"\n    valid_file = output_path / \"validate.parquet\"\n\n    # 2. --- Get task info from LIBERO benchmark ---\n    try:\n        task_suite = benchmark.get_benchmark_dict()[task_suite_name]()\n        num_tasks_in_suite = task_suite.n_tasks\n    except KeyError as err:\n        raise ValueError(\n            f\"Task suite '{task_suite_name}' not found in benchmark.\"\n        ) from err\n\n    logger.info(f\"Found {num_tasks_in_suite} tasks in '{task_suite_name}'.\")\n    logger.info(\n        f\"Requested maximum of {num_trials_per_task} trials per task.\"\n    )\n\n    # 3. --- Query actual initial state counts for each task ---\n    logger.info(\"Querying initial state counts for each task...\")\n    task_initial_state_counts = []\n    for task_id in range(num_tasks_in_suite):\n        initial_states = task_suite.get_task_init_states(task_id)\n        num_initial_states = len(initial_states)\n        task_initial_state_counts.append(num_initial_states)\n        logger.debug(\n            f\"Task {task_id}: {num_initial_states} initial states available\"\n        )\n\n    # 4. 
--- Generate records with per-task trial_id limits ---\n    logger.info(\"Generating dataset records with per-task limits...\")\n    all_records = []\n    total_capped_tasks = 0\n    \n    for task_id in range(num_tasks_in_suite):\n        actual_num_states = task_initial_state_counts[task_id]\n        # Cap at actual available initial states\n        max_trials_for_task = min(actual_num_states, num_trials_per_task)\n        \n        # Log if this task is being capped\n        if max_trials_for_task < num_trials_per_task:\n            logger.info(\n                f\"Task {task_id}: Capping at {max_trials_for_task} trials \"\n                f\"(has {actual_num_states} initial states, requested {num_trials_per_task})\"\n            )\n            total_capped_tasks += 1\n        \n        # Generate records for this task\n        for trial_id in range(max_trials_for_task):\n            all_records.append({\n                \"task_suite_name\": task_suite_name,\n                \"task_id\": task_id,\n                \"trial_id\": trial_id,\n                \"prompt_id\": f\"{task_suite_name}_{task_id}_{trial_id}\",\n            })\n    \n    # Log summary statistics\n    expected_records = num_tasks_in_suite * num_trials_per_task\n    actual_records = len(all_records)\n    logger.info(\n        f\"Generated {actual_records} records \"\n        f\"(expected {expected_records} if all tasks had {num_trials_per_task} states)\"\n    )\n    if total_capped_tasks > 0:\n        logger.info(\n            f\"{total_capped_tasks} out of {num_tasks_in_suite} tasks were capped \"\n            f\"due to insufficient initial states\"\n        )\n\n    # 5. --- Save to both train and validation Parquet files ---\n    try:\n        if all_records:\n            df = pd.DataFrame(all_records)\n            df.to_parquet(train_file, index=False)\n            df.to_parquet(valid_file, index=False)\n            logger.success(\n                f\"✅ VLA task manifests successfully saved to '{output_path}'.\"\n            )\n        else:\n            logger.warning(\"No records were generated for the VLA task manifest.\")\n    except ImportError:\n        logger.error(\n            \"`pandas` and `pyarrow` are required. Please run: `pip install pandas pyarrow`\"\n        )\n        raise\n    except Exception as e:\n        logger.error(f\"Error saving Parquet file for VLA manifest: {e}\")\n        raise\n    return train_file, valid_file\n\n"
  },
  {
    "path": "siirl/data_coordinator/dataloader/partitioned_dataset.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport re\n\nfrom concurrent.futures import ThreadPoolExecutor\nfrom collections import defaultdict\nfrom tqdm import tqdm\nfrom typing import Dict, Optional, Sequence\n\nimport datasets\nimport numpy as np\nimport pyarrow as pa\nimport pyarrow.parquet as pq\nimport torch\nfrom loguru import logger\nfrom torch.utils.data import Dataset\nfrom transformers import PreTrainedTokenizer, ProcessorMixin\n\nimport siirl.utils.model_utils.torch_functional as F\nfrom siirl.utils.model_utils.model import compute_position_id_with_mask\nfrom siirl.params import SiiRLArguments\n\n\ndef collate_fn(data_list: list[dict]) -> dict:\n    \"\"\"\n    Collate a batch of sample dicts into batched tensors and arrays.\n\n    Args:\n        data_list: List of dicts mapping feature names to torch.Tensor or other values.\n\n    Returns:\n        Dict where tensor entries are stacked into a torch.Tensor of shape\n        (batch_size, *dims) and non-tensor entries are converted to\n        np.ndarray of dtype object with shape (batch_size,).\n    \"\"\"\n    tensors = defaultdict(list)\n    non_tensors = defaultdict(list)\n    \n    # Fields that should be converted to tensors for embodied tasks\n    embodied_tensor_fields = {'task_id', 'trial_id'}\n\n    for data in data_list:\n        for key, val in data.items():\n            if isinstance(val, torch.Tensor):\n                tensors[key].append(val)\n            elif key in embodied_tensor_fields and isinstance(val, (int, np.integer)):\n                tensors[key].append(torch.tensor(val, dtype=torch.int64))\n            else:\n                non_tensors[key].append(val)\n\n    for key, val in tensors.items():\n        tensors[key] = torch.stack(val, dim=0)\n\n    for key, val in non_tensors.items():\n        non_tensors[key] = np.array(val, dtype=object)\n\n    return {**tensors, **non_tensors}\n\n\nclass PartitionedRLHFDataset(Dataset):\n    \"\"\"\n    An efficient Dataset class for distributed training. 
It only load and process\n    the data partition (slice) of the RLHF dataset corresponding to the current DDP rank.\n\n    Args:\n        config (SiiRLArguments): Configuration object containing data and preprocessing arguments.\n        tokenizer (PreTrainedTokenizer): Tokenizer for processing text prompts.\n        processor (Optional[ProcessorMixin]): Optional processor for multi-modal data (e.g., images, videos).\n        ddp_rank (int): The rank of the current process in DDP.\n        ddp_world_size (int): Total number of DDP processes (world size).\n        is_eval (bool): Whether the dataset is for evaluation (True) or training (False).\n        drop_last (Optional[bool]): Whether to drop the last remainder\n            if total rows is not divisible by world size.\n            Defaults to True for training, False for evaluation.\n\n    Notes:\n        - This class is optimized for distributed training scenarios, ensuring each DDP process only\n            loads and processes its own data partition.\n        - Supports multi-modal data (text, images, videos) if a processor is provided.\n        - Handles prompt filtering, truncation, and tokenization according to configuration.\n        - Uses multiprocessing for efficient data preprocessing.\n    \"\"\"\n\n    def __init__(\n        self,\n        config: SiiRLArguments,\n        tokenizer: PreTrainedTokenizer,\n        processor: Optional[ProcessorMixin] = None,\n        ddp_rank: int = 0,\n        ddp_world_size: int = 1,\n        is_eval: bool = False,\n        drop_last: Optional[bool] = None,\n    ):\n        super().__init__()\n        self.tokenizer = tokenizer\n        self.processor = processor\n        self.data_args = config.data\n        self._rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0\n        self.ddp_rank = ddp_rank\n        self.ddp_world_size = ddp_world_size\n        self.is_eval = is_eval\n        self.drop_last = drop_last if drop_last is not None else (not is_eval)\n        \n        # The dataset type determines the processing pipeline.\n        # 'llm': Standard processing for prompt-based datasets.\n        # 'embodied': Minimal processing for embodied task manifests, passing metadata through.\n        self.dataset_type = self.data_args.dataset_type\n        if self._rank == 0:\n            logger.info(f\"Initializing dataset with dataset_type='{self.dataset_type}'.\")\n\n        self.prompt_key = self.data_args.prompt_key\n        self.image_key = self.data_args.image_key\n        self.video_key = self.data_args.video_key\n        self.max_prompt_length = self.data_args.max_prompt_length\n        self.truncation = self.data_args.truncation\n        self.return_raw_chat = self.data_args.return_raw_chat\n        self.return_full_prompt = self.data_args.return_full_prompt\n        self.filter_overlong_prompts = self.data_args.filter_overlong_prompts\n        self.num_workers = self.data_args.preprocessing_num_workers if self.data_args.preprocessing_num_workers else max(1, os.cpu_count() // 8)\n        self.force_on_the_fly = config.data.force_on_the_fly\n        self.image_max_pixels = self.data_args.processor.image_max_pixels\n        self.image_min_pixels = self.data_args.processor.image_min_pixels\n        self.video_max_pixels = self.data_args.processor.video_max_pixels\n        self.video_min_pixels = self.data_args.processor.video_min_pixels\n        self.video_fps = self.data_args.processor.video_fps\n        self.video_maxlen = 
self.data_args.processor.video_maxlen\n        self.multi_turn = config.actor_rollout_ref.rollout.multi_turn.enable\n\n        self.is_trailing_rank = False  # Indicates trailing ranks that received one less data item in round-robin partitioning.\n\n        if self._rank == 0:\n            logger.debug(f\"Initializing PartitionedRLHFDataset with DDP rank {self.ddp_rank}, world size {self.ddp_world_size}, is_eval={self.is_eval}, drop_last={self.drop_last}\")\n\n            if self.processor is not None:\n                logger.info(f\"Set image_max_pixels={self.image_max_pixels}, image_min_pixels={self.image_min_pixels}, video_max_pixels={self.video_max_pixels}, video_min_pixels={self.video_min_pixels}, you can change these values via data.processor.image_max_pixels, etc.\")\n\n        # 1. Load the raw data partition for the current DDP rank\n        dataset_files = self.data_args.val_files if is_eval else self.data_args.train_files\n        raw_dataframe = self._load_partitioned_raw_data(dataset_files)\n\n        if raw_dataframe is None or len(raw_dataframe) == 0:\n            logger.warning(f\"DDP rank {self.ddp_rank} received no data.\")\n            self.processed_dataframe = None\n            return\n\n        # 2. Filter out prompts that are too long from the loaded partition\n        raw_dataframe = self._filter_overlong_prompts(raw_dataframe)\n        \n        # If this rank received fewer samples due to uneven partitioning, pad by duplicating the last row.\n        if self.is_trailing_rank and raw_dataframe is not None and len(raw_dataframe) > 0:\n            try:\n                # Duplicate the last row to pad the partition\n                last_row = raw_dataframe[-1].copy()\n                if \"extra_info\" in last_row and isinstance(last_row[\"extra_info\"], dict):\n                    last_row[\"extra_info\"][\"padded_duplicate\"] = True\n                else:\n                    last_row[\"extra_info\"] = {\"padded_duplicate\": True}\n                last_row_ds = datasets.Dataset.from_list([last_row])\n                raw_dataframe = datasets.concatenate_datasets([raw_dataframe, last_row_ds])\n                logger.debug(f\"DDP rank {self.ddp_rank} is a trailing rank, duplicating last row to pad partition. New length: {len(raw_dataframe)}\")\n            except Exception:\n                # We can safely ignore this exception because we mainly rely on the 'padded_duplicate' flag to identify padded elements.\n                pass\n        \n        # Based on the dataset type, we either perform full preprocessing or just\n        # pass the raw embodied metadata through.\n        # Initialize load_on_the_fly before branching to ensure it's always set\n        self.load_on_the_fly = self.force_on_the_fly\n        \n        if self.dataset_type == \"embodied\":\n            # For embodied, we do no preprocessing. The dataset is just the raw manifest.\n            # The tokenizer and processor are not used at this stage.\n            self.processed_dataframe = raw_dataframe\n            if self._rank == 0:\n                logger.info(\"Embodied dataset type detected. Skipping tokenization and prompt processing.\")\n        else:\n            # For 'llm' (default), proceed with the original preprocessing logic.\n            # 3. 
Preprocess the entire partition using multiple processes\n            # By only removing the specific prompt_key, we ensure that other columns,\n            # including complex types like dicts and strings from the original dataset,\n            # are preserved. The .map() function will then add the new columns\n            # returned by _preprocess_function. This is safer than removing all\n            # columns and rebuilding the dataset from scratch.\n            if self.load_on_the_fly:\n                self.processed_dataframe = raw_dataframe\n            else:\n                if self._rank == 0:\n                    logger.warning(\"Currently preloading and preprocessing the entire dataset. If you encounter Out-Of-Memory issues, \"\n                     f\"please set data.force_on_the_fly=True to enable on-the-fly loading mode. \"\n                     f\"Now the dataset type is {self.dataset_type}.\")\n                with ThreadPoolExecutor(max_workers=self.num_workers) as executor:\n                    self.processed_dataframe = list(tqdm(\n                        executor.map(self._preprocess_function, raw_dataframe),\n                        total=len(raw_dataframe),\n                        desc=\"Processing\"\n                    ))\n\n    def _load_partitioned_raw_data(self, dataset_files: Sequence[str]) -> Optional[datasets.Dataset]:\n        \"\"\"\n        Loads a partition of Parquet data for the current DDP rank.\n        \"\"\"\n        if not dataset_files:\n            raise RuntimeError(\"No dataset files configured, aborting...\")\n\n        try:\n            pq_files = [pq.ParquetFile(f) for f in dataset_files]\n\n            # Gather (file_idx, row_group_idx, num_rows, start_row_idx_global)\n            row_group_infos = []\n            total_rows = 0\n            for file_idx, pq_file in enumerate(pq_files):\n                for rg_idx in range(pq_file.num_row_groups):\n                    num_rows = pq_file.metadata.row_group(rg_idx).num_rows\n                    row_group_infos.append({\"file_idx\": file_idx, \"row_group_idx\": rg_idx, \"num_rows\": num_rows, \"start_row_idx_global\": total_rows})\n                    total_rows += num_rows\n\n            if self._rank == 0:\n                logger.debug(f\"DDP rank={self.ddp_rank}, row group infos: {row_group_infos}\")\n\n            if total_rows < self.ddp_world_size:\n                raise RuntimeError(f\"Total rows ({total_rows}) is less than DDP world size ({self.ddp_world_size}), \"\n                                   f\"cannot partition data across ranks. Please ensure enough data is available. \"\n                                   f\"Now the dataset type is {self.dataset_type}.\")\n\n            # Compute partition indices\n            if self.drop_last:\n                rows_per_rank = total_rows // self.ddp_world_size\n                total_used_rows = rows_per_rank * self.ddp_world_size\n                start = self.ddp_rank * rows_per_rank\n                end = start + rows_per_rank\n                if self._rank == 0:\n                    logger.warning(\n                        f\"DDP Rank {self.ddp_rank} using drop_last=True, partitioning rows into {self.ddp_world_size} ranks\"\n                        f\"with {rows_per_rank} rows each. Total used rows: {total_used_rows}, start={start}, end={end}. \"\n                        f\"Total rows: {total_rows}, total dropped rows: {total_rows - total_used_rows}. 
\"\n                    )\n            else:\n                # Distribute the remainder to the first (total_rows % ddp_world_size) ranks\n                rows_per_rank = total_rows // self.ddp_world_size\n                remainder = total_rows % self.ddp_world_size\n                if self.ddp_rank < remainder:\n                    start = self.ddp_rank * (rows_per_rank + 1)\n                    end = start + rows_per_rank + 1\n                else:\n                    start = remainder * (rows_per_rank + 1) + (self.ddp_rank - remainder) * rows_per_rank\n                    end = start + rows_per_rank\n                    self.is_trailing_rank = True # There is one less sample compared to the previous ranks.\n\n            if start >= end:\n                raise RuntimeError(f\"Rank {self.ddp_rank} assigned empty partition: start={start}, end={end}, total_rows={total_rows}\")\n\n            # Find which row groups overlap with [start, end)\n            selected_chunks = []\n            for info in row_group_infos:\n                rg_start = info[\"start_row_idx_global\"]\n                rg_end = rg_start + info[\"num_rows\"]\n                # If overlap\n                if rg_end > start and rg_start < end:\n                    # Compute local slice within this row group\n                    local_start = max(0, start - rg_start)\n                    local_end = min(info[\"num_rows\"], end - rg_start)\n                    selected_chunks.append({\"file_idx\": info[\"file_idx\"], \"row_group_idx\": info[\"row_group_idx\"], \"local_start\": local_start, \"local_end\": local_end})\n\n            # Read and slice the necessary row groups\n            tables = []\n            for chunk in selected_chunks:\n                pq_file = pq_files[chunk[\"file_idx\"]]\n                table = pq_file.read_row_group(chunk[\"row_group_idx\"])\n                if chunk[\"local_start\"] > 0 or chunk[\"local_end\"] < table.num_rows:\n                    table = table.slice(chunk[\"local_start\"], chunk[\"local_end\"] - chunk[\"local_start\"])\n                tables.append(table)\n\n            if not tables:\n                raise RuntimeError(f\"DDP Rank {self.ddp_rank} assigned rows [{start}, {end}) but failed to read any data.\")\n\n            final_table = pa.concat_tables(tables)\n            logger.debug(f\"DDP rank={self.ddp_rank} loaded {len(final_table)} rows from {len(tables)} row groups. start={start}, end={end}, total_rows={total_rows}.\")\n            return datasets.Dataset(final_table)\n\n        except Exception as e:\n            logger.error(f\"Failed during partitioned data loading for DDP rank {self.ddp_rank}: {dataset_files}. 
Error: {e}\")\n            raise\n\n    def _filter_overlong_prompts(self, raw_dataframe: datasets.Dataset) -> datasets.Dataset:\n        if self.filter_overlong_prompts:\n            original_len = len(raw_dataframe)\n            raw_dataframe = raw_dataframe.filter(\n                lambda doc: len(self.tokenizer.apply_chat_template(doc[self.prompt_key], add_generation_prompt=True)) <= self.max_prompt_length,\n                num_proc=self.num_workers,\n                desc=f\"Rank {self.ddp_rank} filtering prompts longer than {self.max_prompt_length} tokens\",\n            )\n            filtered_len = len(raw_dataframe)\n            if self._rank == 0:\n                logger.info(f\"Filtered prompts from {original_len} to {filtered_len} on each rank.\")\n        return raw_dataframe\n\n    def __len__(self) -> int:\n        return len(self.processed_dataframe) if self.processed_dataframe is not None else 0\n\n    def _build_messages(self, example: dict) -> list:\n        \"\"\"Helper function to structure messages, adopted from RLHFDataset.\"\"\"\n        # The pop operation is safe here because map creates a copy for each process\n        messages: list = example.pop(self.prompt_key)\n\n        if self.image_key in example or self.video_key in example:\n            for message in messages:\n                content = message[\"content\"]\n                content_list = []\n                # Simple split logic to handle interleaved text and images/videos\n                for segment in re.split(\"(<image>|<video>)\", content):\n                    if segment == \"<image>\":\n                        content_list.append({\"type\": \"image\"})\n                    elif segment == \"<video>\":\n                        content_list.append({\"type\": \"video\"})\n                    elif segment:  # Avoid adding empty strings\n                        content_list.append({\"type\": \"text\", \"text\": segment})\n                message[\"content\"] = content_list\n        return messages\n\n    def _preprocess_function(self, row_dict: Dict) -> Dict:\n        \"\"\"\n        The core preprocessing logic applied to each sample via `datasets.map()`.\n        \"\"\"\n        processed_row = row_dict.copy()\n        messages = self._build_messages(processed_row)\n        model_inputs = {}\n\n        # The output dict of this function becomes a row in the new dataset\n\n        if self.processor is not None:\n            # Note: Vision processing is kept here for simplicity.\n            # For extreme performance, you might consider pre-serializing images/videos.\n            from siirl.data_coordinator.dataloader.vision_utils import process_image, process_video\n\n            raw_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)\n\n            multi_modal_data = {}\n            images = None\n            if self.image_key in processed_row:\n                images = [process_image(image, self.image_max_pixels, self.image_min_pixels) for image in processed_row.pop(self.image_key)]\n                multi_modal_data[\"image\"] = images\n            videos = None\n            if self.video_key in processed_row:\n                videos = [process_video(video, fps=self.video_fps, fps_max_frames=self.video_maxlen, \n                                        max_pixels=self.video_max_pixels, min_pixels=self.video_min_pixels) \n                                        for video in processed_row.pop(self.video_key)]\n                multi_modal_data[\"video\"] = 
[video.numpy() for video in videos]\n\n            model_inputs = self.processor(text=[raw_prompt], images=images, videos=videos, return_tensors=\"pt\")\n            input_ids = model_inputs.pop(\"input_ids\")\n            attention_mask = model_inputs.pop(\"attention_mask\")\n\n            if \"second_per_grid_ts\" in model_inputs:\n                model_inputs.pop(\"second_per_grid_ts\")\n\n            # There's a trap here, multi_modal_inputs has to be a dict, not BatchFeature\n            processed_row[\"multi_modal_data\"] = multi_modal_data\n            processed_row[\"multi_modal_inputs\"] = dict(model_inputs)\n\n            # second_per_grid_ts isn't used for training, just for mrope\n            processed_row[\"multi_modal_inputs\"].pop(\"second_per_grid_ts\", None)\n        else:\n            raw_prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)\n            model_inputs = self.tokenizer(raw_prompt, return_tensors=\"pt\", add_special_tokens=False)\n            input_ids = model_inputs.pop(\"input_ids\")\n            attention_mask = model_inputs.pop(\"attention_mask\")\n\n        input_ids, attention_mask = F.postprocess_data(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            max_length=self.max_prompt_length,\n            pad_token_id=self.tokenizer.pad_token_id,\n            left_pad=True,\n            truncation=self.truncation,\n        )\n\n        if self.processor is not None and self.processor.image_processor.__class__.__name__ == \"Qwen2VLImageProcessor\":\n            from siirl.models.transformers.qwen2_vl import get_rope_index\n\n            position_ids = [\n                get_rope_index(\n                    self.processor,\n                    input_ids=input_ids[0],\n                    image_grid_thw=model_inputs.get(\"image_grid_thw\"),\n                    video_grid_thw=model_inputs.get(\"video_grid_thw\"),\n                    second_per_grid_ts=model_inputs.get(\"second_per_grid_ts\"),\n                    attention_mask=attention_mask[0],\n                )\n            ]  # (1, 3, seq_len)\n\n        else:\n            position_ids = compute_position_id_with_mask(attention_mask)\n\n        processed_row[\"input_ids\"] = input_ids[0]\n        processed_row[\"attention_mask\"] = attention_mask[0]\n        processed_row[\"position_ids\"] = position_ids[0]\n\n        # Handle raw_prompt_ids for potential combination with other templates\n        raw_prompt_ids = self.tokenizer.encode(raw_prompt, add_special_tokens=False)\n        if len(raw_prompt_ids) > self.max_prompt_length:\n            if self.truncation == \"left\":\n                raw_prompt_ids = raw_prompt_ids[-self.max_prompt_length :]\n            elif self.truncation == \"right\":\n                raw_prompt_ids = raw_prompt_ids[: self.max_prompt_length]\n            elif self.truncation == \"middle\":\n                left_half = self.max_prompt_length // 2\n                right_half = self.max_prompt_length - left_half\n                raw_prompt_ids = raw_prompt_ids[:left_half] + raw_prompt_ids[-right_half:]\n            elif self.truncation == \"error\":\n                raise RuntimeError(f\"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}.\")\n\n        processed_row[\"raw_prompt_ids\"] = raw_prompt_ids\n        if self.return_raw_chat:\n            processed_row[\"raw_prompt\"] = messages\n        if self.return_full_prompt:\n            processed_row[\"full_prompts\"] = 
raw_prompt  # array of strings\n\n        # add index for each prompt\n        if self.multi_turn:\n            index = processed_row.get(\"extra_info\", {}).get(\"index\", 0)\n            tools_kwargs = processed_row.get(\"extra_info\", {}).get(\"tools_kwargs\", {})\n            interaction_kwargs = processed_row.get(\"extra_info\", {}).get(\"interaction_kwargs\", {})\n            # need_tools_kwargs = row_dict.get(\"extra_info\", {}).get(\"need_tools_kwargs\", self.need_tools_kwargs)\n            # if need_tools_kwargs and not tools_kwargs:\n            #     logger.warning(\"tools_kwargs is empty for index {}, data source: {}\", index, row_dict[\"data_source\"])\n            processed_row[\"index\"] = index\n            processed_row[\"tools_kwargs\"] = tools_kwargs\n            processed_row[\"interaction_kwargs\"] = interaction_kwargs\n        return processed_row\n\n    def __getitem__(self, item: int) -> Dict:\n        \"\"\"\n        Returns a preprocessed item from the dataset.\n        \"\"\"\n        if self.processed_dataframe is None:\n            raise IndexError(\"Dataset is empty or not initialized properly.\")\n        \n        # For embodied, __getitem__ is a simple lookup.\n        # For LLM, it may involve on-the-fly preprocessing.\n        if self.dataset_type == \"embodied\":\n            return self.processed_dataframe[item]\n        else:\n            return self.processed_dataframe[item] if not self.load_on_the_fly else self._preprocess_function(self.processed_dataframe[item])\n"
  },
  {
    "path": "siirl/data_coordinator/dataloader/vision_utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom io import BytesIO\nfrom typing import Optional, Union\n\nimport torch\nfrom PIL import Image\nfrom qwen_vl_utils import fetch_image, fetch_video, smart_resize\n\n\ndef process_image(image: Union[dict, Image.Image], max_pixels: int, min_pixels: int) -> Image.Image:\n    img_obj = None\n    if isinstance(image, Image.Image):\n        img_obj = image.convert(\"RGB\")\n    elif \"bytes\" in image:\n        assert \"image\" not in image, \"Cannot have both `bytes` and `image`\"\n        img_obj = Image.open(BytesIO(image[\"bytes\"])).convert(\"RGB\")\n\n    if img_obj:\n        width, height = img_obj.size\n        resized_height, resized_width = smart_resize(\n            height,\n            width,\n            min_pixels=min_pixels,\n            max_pixels=max_pixels,\n        )\n        return img_obj.resize((resized_width, resized_height))\n\n    image[\"min_pixels\"] = min_pixels\n    image[\"max_pixels\"] = max_pixels\n    return fetch_image(image)\n\n\nVIDEO_FORMAT_HELP = \"\"\"Currently, we only support the video formats introduced in qwen2-vl.\nRefer to https://github.com/QwenLM/Qwen2.5-VL?tab=readme-ov-file#using---transformers-to-chat.\n\neg.\n{\n    \"type\": \"video\",\n    \"video\": [\n        \"file:///path/to/frame1.jpg\",\n        \"file:///path/to/frame2.jpg\"\n    ]\n}\n\n{\n    \"type\": \"video\",\n    \"video\": \"file:///path/to/video.mp4\"\n}\n# Defaults to fps=2, min_frames=4, max_frames=768\n\n{\n    \"type\": \"video\",\n    \"video\": \"file:///path/to/video.mp4\",\n    \"fps\": 2,\n    \"min_frames\": 1,\n    \"max_frames\": 32\n}\n\"\"\"\n\n\ndef process_video(\n    video: dict,\n    nframes: Optional[int] = None,\n    fps: Optional[float] = None,\n    fps_min_frames: Optional[int] = None,\n    fps_max_frames: Optional[int] = None,\n    max_pixels: Optional[int] = None,\n    min_pixels: Optional[int] = None,\n) -> torch.Tensor:\n    \"\"\"Converts a video dict into a [n_frames, 3, H, W] tensor\n\n    Add video sample FPS in a future MR\n    \"\"\"\n\n    if not isinstance(video, dict) or \"video\" not in video:\n        raise NotImplementedError(VIDEO_FORMAT_HELP)\n    assert nframes is None or fps is None, \"Can't use both `nframes` or `fps`\"\n\n    # Shallow copy... 
since we might want to add some keys\n    video = dict(video)\n\n    if max_pixels is not None:\n        video[\"max_pixels\"] = max_pixels\n    if min_pixels is not None:\n        video[\"min_pixels\"] = min_pixels\n    contains_sampling_rules = \"nframes\" in video or \"fps\" in video\n    if not contains_sampling_rules:\n        if nframes is not None:\n            video[\"nframes\"] = nframes\n        elif fps is not None:\n            video[\"fps\"] = fps\n            if fps_min_frames is not None:\n                video[\"min_frames\"] = fps_min_frames\n            if fps_max_frames is not None:\n                video[\"max_frames\"] = fps_max_frames\n\n    return fetch_video(video)\n\n\ndef process_multi_modal_inputs_for_minicpmo(input_ids, attention_mask, position_ids, cu_seqlens, multi_modal_inputs):\n    # Adjust image bounds based on left padding and cumulative sequence lengths\n    # This is necessary for MiniCPM-o's vision-language alignment\n    left_padding_length = torch.argmax(attention_mask, dim=1)\n    image_bounds = []\n    for i in range(len(multi_modal_inputs[\"image_bound\"])):\n        image_bound = (\n            multi_modal_inputs[\"image_bound\"][i].to(left_padding_length.device) - left_padding_length[i] + cu_seqlens[i]\n        )\n        image_bounds.append(image_bound)\n\n    # Flatten pixel values list for MiniCPM-o processing\n    pixel_values = []\n    for i in range(len(multi_modal_inputs[\"pixel_values\"])):\n        pixel_values.extend([p for p in multi_modal_inputs[\"pixel_values\"][i]])\n\n    multi_modal_inputs[\"pixel_values\"] = [pixel_values]\n    multi_modal_inputs[\"image_bound\"] = [torch.vstack(image_bounds)]\n    multi_modal_inputs[\"tgt_sizes\"] = [torch.vstack(multi_modal_inputs[\"tgt_sizes\"])]\n    multi_modal_inputs[\"input_ids\"] = input_ids\n    multi_modal_inputs[\"attention_mask\"] = attention_mask\n    multi_modal_inputs[\"position_ids\"] = position_ids\n    return {\"data\": multi_modal_inputs}\n"
  },
  {
    "path": "siirl/data_coordinator/protocol.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nImplement base data transfer protocol between any two functions, modules.\nWe can subclass Protocol to define more detailed batch info with specific keys\n\"\"\"\n\nimport contextlib\nimport copy\nimport logging\nimport os\nimport pickle\nfrom copy import deepcopy\nfrom dataclasses import dataclass, field\nfrom typing import Callable, Dict, List, Optional, Union\n\nimport numpy as np\nimport pandas as pd\nimport ray\nimport tensordict\nimport torch\nimport torch.distributed\nimport torch.nn.functional as F\nfrom packaging import version\nfrom tensordict import TensorDict\nfrom torch.utils.data import DataLoader\n\nfrom siirl.utils.extras.device import get_device_id, get_torch_device\nfrom siirl.utils.extras.py_functional import union_two_dict\nfrom siirl.utils.model_utils.torch_functional import allgather_dict_tensors\n\n__all__ = [\"union_tensor_dict\"]\n\nwith contextlib.suppress(Exception):\n    tensordict.set_lazy_legacy(False).set()\n\n\nclass _TensorDictConfigMeta(type):\n    _config = {}\n\n    auto_padding_key = \"_siirl_auto_padding\"\n\n    @property\n    def auto_padding(cls):\n        enabled_by_env = os.getenv(\"SIIRL_AUTO_PADDING\", \"FALSE\").upper() in [\"TRUE\", \"1\"]\n        return enabled_by_env or cls._config.get(cls.auto_padding_key, False)\n\n    @auto_padding.setter\n    def auto_padding(cls, enabled: bool):\n        assert isinstance(enabled, bool), f\"enabled must be a boolean, got {enabled} as {type(enabled)}\"\n        cls._config[cls.auto_padding_key] = enabled\n\n\nclass TensorDictConfig(metaclass=_TensorDictConfigMeta):\n    pass\n\n\n_padding_size_key = \"_padding_size_key_x123d\"\n\n\ndef union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> TensorDict:\n    \"\"\"Union two tensordicts.\"\"\"\n    assert tensor_dict1.batch_size == tensor_dict2.batch_size, (\n        f\"Two tensor dict must have identical batch size. 
Got {tensor_dict1.batch_size} and {tensor_dict2.batch_size}\"\n    )\n    for key in tensor_dict2.keys():\n        if key not in tensor_dict1.keys():\n            tensor_dict1[key] = tensor_dict2[key]\n        else:\n            assert tensor_dict1[key].equal(tensor_dict2[key]), (\n                f\"{key} in tensor_dict1 and tensor_dict2 are not the same object\"\n            )\n\n    return tensor_dict1\n\n\ndef union_numpy_dict(tensor_dict1: dict[str, np.ndarray], tensor_dict2: dict[str, np.ndarray]) -> dict[str, np.ndarray]:\n    for key, val in tensor_dict2.items():\n        if key in tensor_dict1:\n            assert isinstance(tensor_dict2[key], np.ndarray)\n            assert isinstance(tensor_dict1[key], np.ndarray)\n            # to properly deal with nan and object type\n            assert pd.DataFrame(tensor_dict2[key]).equals(pd.DataFrame(tensor_dict1[key])), (\n                f\"{key} in tensor_dict1 and tensor_dict2 are not the same object\"\n            )\n        tensor_dict1[key] = val\n\n    return tensor_dict1\n\n\ndef list_of_dict_to_dict_of_list(list_of_dict: list[dict]):\n    if len(list_of_dict) == 0:\n        return {}\n    keys = list_of_dict[0].keys()\n    output = {key: [] for key in keys}\n    for data in list_of_dict:\n        for key, item in data.items():\n            assert key in output\n            output[key].append(item)\n    return output\n\n\ndef all_gather_data_proto(data: TensorDict, process_group):\n    # Note that this is an inplace operator just like torch.distributed.all_gather\n    group_size = torch.distributed.get_world_size(group=process_group)\n    print(f\"all gather dataproto process_group size:{group_size}\")\n    assert isinstance(data, TensorDict)\n    prev_device = data.device\n    data = data.to(get_device_id())\n    data = allgather_dict_tensors(data.contiguous(), size=group_size, group=process_group, dim=0)\n    data = data.to(prev_device)\n\n\ndef select_idxs(batch: TensorDict, idxs):\n    \"\"\"\n    Select specific indices from the TensorDict.\n\n    Args:\n        batch (TensorDict): data to be select\n        idxs (torch.Tensor or numpy.ndarray or list): Indices to select\n\n    Returns:\n        TensorDict: A new TensorDict containing only the selected indices\n    \"\"\"\n    if isinstance(idxs, list):\n        idxs = torch.tensor(idxs)\n        if idxs.dtype != torch.bool:\n            idxs = idxs.type(torch.int32)\n\n    if isinstance(idxs, np.ndarray):\n        idxs_np = idxs\n        idxs_torch = torch.from_numpy(idxs)\n    else:  # torch.Tensor\n        idxs_torch = idxs\n        idxs_np = idxs.detach().cpu().numpy()\n\n    batch_size = int(idxs_np.sum()) if idxs_np.dtype == bool else idxs_np.shape[0]\n    filtered_data = {}\n    for key, value in batch.items():\n        if isinstance(value, torch.Tensor):\n            filtered_data[key] = value[idxs_torch]\n        elif isinstance(value, np.ndarray):\n            filtered_data[key] = value[idxs_np]\n        elif isinstance(value, list):\n            filtered_data[key] = np.ndarray(value)[idxs_np]\n        else:\n            filtered_data[key] = value\n    return TensorDict(filtered_data, batch_size=batch_size)"
  },
  {
    "path": "siirl/data_coordinator/sample.py",
    "content": "import torch\nimport numpy as np\nimport asyncio\nimport ray\nimport uuid\nfrom pydantic import BaseModel, Field\nfrom typing import Any, Dict, List, Optional, Union, Set\nfrom tensordict import TensorDict\n# Handle different tensordict versions - NonTensorData location varies\ntry:\n    from tensordict import NonTensorData\nexcept ImportError:\n    from tensordict.tensorclass import NonTensorData\nfrom typing import get_args, get_origin\n\n\nclass SampleInfo(BaseModel):\n    agent_group: int = Field(default=0)\n    sum_tokens: int = Field(default=0)\n    prompt_length: int = Field(default=0)\n    response_length: int = Field(default=0)\n    dict_info: Dict[str, Any] = Field(default_factory=dict)\n    uid: Optional[str] = Field(default=None)\n    node_id: Optional[str] = Field(default=None)\n\n\nclass Sample(BaseModel):\n    # from tensordict of Dataproto\n    prompts: Optional[np.ndarray] = Field(default=None)\n    responses: Optional[np.ndarray] = Field(default=None)\n    response_mask: Optional[np.ndarray] = Field(default=None)\n    input_ids: Optional[np.ndarray] = Field(\n        default=None,\n        metadata={\"help\": \"initial: prompts with pad, after rollout: prompts + response\"}\n    )\n    attention_mask: Optional[np.ndarray] = Field(default=None)\n    position_ids: Optional[np.ndarray] = Field(default=None)\n    acc: float = Field(default=None)\n    token_level_rewards: Optional[np.ndarray] = Field(default=None)\n    token_level_scores: Optional[np.ndarray] = Field(default=None)\n    values: Optional[np.ndarray] = Field(default=None)\n    advantages: Optional[np.ndarray] = Field(default=None)\n    returns: Optional[np.ndarray] = Field(default=None)\n    old_log_probs: Optional[np.ndarray] = Field(default=None)\n    ref_log_prob: Optional[np.ndarray] = Field(default=None)\n    \n    # used for vla\n    pixel_values: Optional[np.ndarray] = Field(default=None)\n    finish_step: Optional[np.ndarray] = Field(default=None)\n    complete: Optional[np.ndarray] = Field(default=None)\n    task_file_name: Optional[np.ndarray] = Field(default=None)\n    vjepa_embedding: Optional[np.ndarray] = Field(default=None)\n    \n    # from  non_tensor_batch of Dataproto\n    raw_prompt: str = Field(default=\"\")\n    raw_prompt_ids: List[int] = Field(default_factory=list)\n    prompt_texts: str = Field(\n        default=\"\",\n        metadata={\"help\": \"used in validate, decode form 'input_ids', but may is same with raw_prompt\"}\n    )\n    reward_model: Dict[str, Any] = Field(default_factory=dict)\n    data_source: str = Field(\n        default=\"\",\n        metadata={\"help\": \"name of datasets\"}\n    )\n    extra_info: Dict[str, Any] = Field(default_factory=dict)\n    tools_kwargs: Dict[str, Any] = Field(\n        default_factory=dict,\n        metadata={\"help\": \"used in multi-turn, may not need to be sample dimension\"}\n    )\n    interaction_kwargs: Dict[str, Any] = Field(\n        default_factory=dict,\n        metadata={\"help\": \"used in multi-turn, may not need to be sample dimension\"}\n    )\n    request_id: str = Field(\n        default=\"\",\n        metadata={\"help\": \"used in multi-agent\"}\n    )\n    traj_len: int = Field(\n        default=None,\n        metadata={\"help\": \"used in multi-agent\"}\n    )\n    traj_step: int = Field(\n        default=None,\n        metadata={\"help\": \"used in multi-agent\"}\n    )\n    seq_final_reward: float = Field(\n        default=None,\n        metadata={\"help\": \"used in dapo\"}\n    )\n    seq_reward: 
float = Field(\n        default=None,\n        metadata={\"help\": \"used in dapo\"}\n    )\n    multi_modal_inputs: Optional[Dict[str, Any]] = Field(default=None)\n    uid: Optional[str] = Field(default=None)\n\n    temperature: float = Field(\n        default=None,\n        metadata={\"help\": \"temperature\"}\n    )\n\n    class Config:\n        arbitrary_types_allowed = True\n\n\nclass SampleManager(BaseModel):\n    sample_info: Optional[SampleInfo] = Field(default=None)\n    sample: Optional[Union[Sample, ray.ObjectRef]] = Field(default=None)\n\n    class Config:\n        arbitrary_types_allowed = True\n\n\n\ndef preprocess_dataloader(data:Dict, n:int = 1):\n    from loguru import logger\n    # Manually repeat all numpy arrays and torch tensors\n    batch_size = None\n    for key, value in data.items():\n        if isinstance(value, np.ndarray):\n            # Repeat numpy arrays along axis 0\n            data[key] = np.repeat(value, n, axis=0)\n        elif isinstance(value, torch.Tensor):\n            # Repeat torch tensors along dim 0\n            if batch_size is None:\n                batch_size = value.shape[0]\n            data[key] = value.repeat_interleave(n, dim=0)\n            \n        elif isinstance(value, list):\n            # Convert list to numpy array and repeat\n            data[key] = np.repeat(np.array(value), n, axis=0)\n            \n    # Create UUID indices for GRPO grouping\n    # Each prompt gets a unique UUID, then repeated n times\n    uid = np.array([str(uuid.uuid4()) for _ in range(batch_size)])\n    data['uid'] = np.repeat(uid, n, axis=0)\n    # Now all fields have batch_size * n\n    # Create TensorDict with the expanded batch size\n    tensor_dict = TensorDict(data, batch_size=batch_size * n)\n\n    return tensor_dict\n\ndef Dict2Samples(data:TensorDict)-> List[SampleManager]:\n    batch_size = data.batch_size[0]\n    async def calc_sample(index):\n        local_sample = Sample()\n        local_sample.input_ids = data['input_ids'][index].numpy() if 'input_ids' in data else None\n        local_sample.attention_mask = data['attention_mask'][index].numpy() if 'attention_mask' in data else None\n        local_sample.position_ids = data['position_ids'][index].numpy() if 'position_ids' in data else None\n        local_sample.data_source = data['data_source'][index] if 'data_source' in data else None\n        local_sample.reward_model = data['reward_model'][index] if 'reward_model' in data else None\n        local_sample.prompts = data['prompts'][index].numpy() if 'prompts' in data else None\n        local_sample.responses = data['responses'][index].numpy() if 'responses' in data else None\n        local_sample.response_mask = data['response_mask'][index].numpy() if 'response_mask' in data else None\n        local_sample.values = data['values'][index].numpy() if 'values' in data else None\n        local_sample.raw_prompt_ids = data['raw_prompt_ids'][index] if 'raw_prompt_ids' in data else None\n        local_sample.advantages = data['advantages'][index].numpy() if 'advantages' in data else None\n        local_sample.raw_prompt = data['raw_prompt'][index] if 'raw_prompt' in data else None\n        local_sample.returns = data['returns'][index].numpy() if 'returns' in data else None\n        local_sample.token_level_rewards = data['token_level_rewards'][index].numpy() if 'token_level_rewards' in data else None\n        local_sample.token_level_scores = data['token_level_scores'][index].numpy() if 'token_level_scores' in data else None\n        
local_sample.old_log_probs = data['old_log_probs'][index].numpy() if 'old_log_probs' in data else None\n        local_sample.ref_log_prob = data['ref_log_prob'][index].numpy() if 'ref_log_prob' in data else None\n        local_sample.extra_info = data['extra_info'][index] if 'extra_info' in data else None\n        local_sample.pixel_values = data['pixel_values'][index].numpy() if 'pixel_values' in data else None\n        local_sample.finish_step = data['finish_step'][index].numpy() if 'finish_step' in data else None\n        local_sample.complete = data['complete'][index].numpy() if 'complete' in data else None\n        local_sample.task_file_name = data['task_file_name'][index].numpy() if 'task_file_name' in data else None\n        local_sample.vjepa_embedding = data['vjepa_embedding'][index].numpy() if 'vjepa_embedding' in data else None\n        if 'multi_modal_inputs' in data:\n            local_sample.multi_modal_inputs = data[\"multi_modal_inputs\"][index]\n        local_sample.uid = data['uid'][index]\n        # local_sample = ray.put(local_sample)\n        return local_sample\n    loop = asyncio.get_event_loop()\n    futures = []\n    for index in range(batch_size):\n        futures.append(calc_sample(index))\n    samples = loop.run_until_complete(asyncio.gather(*futures))   \n    del data \n    return samples\n\ndef Samples2Dict(samples: List[Sample]) -> TensorDict:\n    async def get_sample(samples, index):\n        # sample = ray.get(samples[index].sample)\n        # samples[index].sample = sample\n        return samples[index]\n    futures = []\n    for i in range(len(samples)):\n        futures.append(get_sample(samples, i))\n    loop = asyncio.get_event_loop()\n    samples = loop.run_until_complete(asyncio.gather(*futures))\n    # convert to tensordict\n    fields = Sample.model_fields \n    sample_fields = [name for name in fields.keys()]\n\n    aggregated: Dict[str, List[Any]] = {}\n    for sample in samples:\n        if sample is None:\n            raise ValueError(\"Sample Should not be None\")\n        for field in sample_fields:\n            val = getattr(sample, field)\n            if val is not None:\n                if isinstance(val, (torch.Tensor, list, np.ndarray, dict, str)):\n                    if field not in aggregated:\n                        aggregated[field] = []\n                    aggregated[field].append(val)\n                elif isinstance(val, (int, float, bool)):\n                    aggregated[field] = val\n                else:\n                    print(f\"key {field} type{type(val)} not support\")       \n    tensordict_data: Dict[str, Any] = {}\n    batch_size = (len(samples),)  \n\n    for key, values in aggregated.items():\n        if isinstance(values, list):\n            first_val = values[0]\n            \n            # if internal val is not \"\"/ {} ...\n            if isinstance(first_val, np.ndarray):\n                tensordict_data[key] = np.stack(values, axis=0) if first_val.ndim >= 1 else np.array(values)\n                default_type = fields[key].annotation\n                if get_origin(default_type) is Union:\n                    args = get_args(default_type)       \n                    actual_type = next((arg for arg in args if arg is not type(None)), None)\n                    if actual_type is np.ndarray:\n                        tensordict_data[key] = torch.tensor(tensordict_data[key]) \n                elif default_type is np.ndarray:\n                    tensordict_data[key] = torch.tensor(tensordict_data[key])\n        
    elif isinstance(first_val, str):\n                if first_val:\n                    tensordict_data[key] = values\n            else:\n                if first_val:\n                    tensordict_data[key] = NonTensorData(\n                        data=values,\n                        batch_size=batch_size\n                    )\n\n        else:\n            tensordict_data[key] = NonTensorData(\n                data=values,\n                batch_size=batch_size\n            )\n\n    tensordict_data[\"global_token_num\"] = NonTensorData(torch.sum(tensordict_data[\"attention_mask\"], dim=-1).flatten().tolist())\n\n    return TensorDict(tensordict_data, batch_size=batch_size)\n\ndef filter_tensordict(batch: TensorDict, indices: List[int]) -> TensorDict:\n    \"\"\"\n    Filter a TensorDict by selecting only the samples at the specified indices.\n    \n    This function is used by DAPO to filter out trajectory groups with zero variance.\n    It properly handles both regular tensor fields and NonTensorData fields.\n    \n    Args:\n        batch: The input TensorDict to filter\n        indices: List of indices to keep\n        \n    Returns:\n        A new TensorDict containing only the selected samples\n    \"\"\"\n    if not indices:\n        # Return an empty TensorDict with the same structure but batch_size=0\n        return TensorDict({}, batch_size=(0,))\n    \n    # Convert indices to both tensor and numpy for different field types\n    indices_tensor = torch.tensor(indices, dtype=torch.long)\n    indices_np = np.array(indices, dtype=np.int64)\n    target_batch_size = len(indices)\n    original_batch_size = batch.batch_size[0] if isinstance(batch.batch_size, tuple) else batch.batch_size\n    \n    from loguru import logger\n    \n    # Manually filter each field to ensure NonTensorData is handled correctly\n    filtered_dict = {}\n    for key, value in batch.items():\n        try:\n            if isinstance(value, NonTensorData):\n                # NonTensorData needs special handling\n                if isinstance(value.data, np.ndarray):\n                    # numpy array wrapped in NonTensorData\n                    data_len = len(value.data)\n                    if data_len == original_batch_size:\n                        # This is batched data - filter it\n                        filtered_data = value.data[indices_np]\n                        filtered_dict[key] = NonTensorData(data=filtered_data, batch_size=[target_batch_size])\n                    else:\n                        # This is metadata (length doesn't match batch_size) - keep as is\n                        filtered_dict[key] = value\n                elif isinstance(value.data, (list, tuple)):\n                    # list/tuple wrapped in NonTensorData\n                    data_len = len(value.data)\n                    if data_len == original_batch_size:\n                        # This is batched data - filter it\n                        filtered_data = [value.data[i] for i in indices]\n                        filtered_dict[key] = NonTensorData(data=filtered_data, batch_size=[target_batch_size])\n                    else:\n                        # This is metadata (length doesn't match batch_size) - keep as is\n                        filtered_dict[key] = value\n                else:\n                    # scalar or other type - keep as is (it's metadata)\n                    filtered_dict[key] = value\n            elif isinstance(value, torch.Tensor):\n                # Regular tensor - use tensor indexing\n      
          filtered_dict[key] = value[indices_tensor]\n            else:\n                # Other types - try tensor indexing or keep as is\n                try:\n                    filtered_dict[key] = value[indices_tensor]\n                except Exception:\n                    filtered_dict[key] = value\n        except Exception as e:\n            # Provide detailed error message for debugging\n            raise RuntimeError(\n                f\"Error filtering field '{key}' in filter_tensordict: {e}\\n\"\n                f\"  Field type: {type(value)}\\n\"\n                f\"  Value type: {type(value.data) if isinstance(value, NonTensorData) else 'N/A'}\\n\"\n                f\"  Data length: {len(value.data) if isinstance(value, NonTensorData) and hasattr(value.data, '__len__') else 'N/A'}\\n\"\n                f\"  Original batch size: {original_batch_size}\\n\"\n                f\"  Target batch size: {target_batch_size}\\n\"\n                f\"  Indices: {indices[:10]}... (showing first 10)\\n\"\n                f\"  Max index: {max(indices) if indices else 'N/A'}\"\n            ) from e\n    \n    # Create new TensorDict with filtered data\n    filtered_batch = TensorDict(filtered_dict, batch_size=target_batch_size)\n    \n    return filtered_batch"
  },
  {
    "path": "siirl/engine/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "siirl/engine/actor/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .base import BasePPOActor\nfrom .dp_actor import DataParallelPPOActor\nfrom .embodied_actor import RobDataParallelPPOActor\n__all__ = [\"BasePPOActor\", \"DataParallelPPOActor\",\"RobDataParallelPPOActor\"]\n"
  },
  {
    "path": "siirl/engine/actor/base.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThe base class for Actor\n\"\"\"\n\nfrom abc import ABC, abstractmethod\nfrom typing import Dict\n\nimport torch\n\nfrom tensordict import TensorDict\n\n__all__ = [\"BasePPOActor\"]\n\n\nclass BasePPOActor(ABC):\n    def __init__(self, config):\n        \"\"\"The base class for PPO actor\n\n        Args:\n            config (DictConfig): a config passed to the PPOActor. We expect the type to be\n                DictConfig (https://omegaconf.readthedocs.io/), but it can be any namedtuple in general.\n        \"\"\"\n        super().__init__()\n        self.config = config\n\n    @abstractmethod\n    def compute_log_prob(self, data: TensorDict) -> torch.Tensor:\n        \"\"\"Compute logits given a batch of data.\n\n        Args:\n            data (TensorDict): a batch of data represented by TensorDict. It must contain key ```input_ids```,\n                ```attention_mask``` and ```position_ids```.\n\n        Returns:\n            TensorDict: a TensorDict containing the key ```log_probs```\n\n\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def update_policy(self, data: TensorDict) -> Dict:\n        \"\"\"Update the policy with an iterator of TensorDict\n\n        Args:\n            data (TensorDict): an iterator over the TensorDict that returns by\n                ```make_minibatch_iterator```\n\n        Returns:\n            Dict: a dictionary contains anything. Typically, it contains the statistics during updating the model\n            such as ```loss```, ```grad_norm```, etc,.\n\n        \"\"\"\n        pass\n"
  },
  {
    "path": "siirl/engine/actor/dp_actor.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nSingle Process Actor\n\"\"\"\n\nimport itertools\nfrom typing import Tuple\n\nimport torch\nimport numpy as np\nfrom tensordict import TensorDict\n# Handle different tensordict versions - NonTensorData location varies\ntry:\n    from tensordict import NonTensorData\nexcept ImportError:\n    from tensordict.tensorclass import NonTensorData\nfrom loguru import logger\nfrom torch import nn\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\nfrom torch.distributed.tensor import DTensor\n\nimport siirl.utils.model_utils.torch_functional as F\nfrom siirl.dag_worker.core_algos import agg_loss, compute_policy_loss, kl_penalty, get_policy_loss_fn\nfrom siirl.utils.debug import GPUMemoryLogger\nfrom siirl.utils.extras.device import get_device_id, get_device_name, is_cuda_available, is_npu_available\nfrom siirl.utils.extras.py_functional import append_to_dict\nfrom siirl.utils.model_utils.fsdp_utils import FSDPModule, fsdp2_clip_grad_norm_\nfrom siirl.utils.extras.py_functional import append_to_dict\nfrom siirl.utils.model_utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches\nfrom siirl.utils.model_utils.torch_functional import logprobs_from_logits\nfrom siirl.utils.model_utils.ulysses import gather_outpus_and_unpad, ulysses_pad, ulysses_pad_and_slice_inputs\nfrom siirl.params import ActorArguments, RefArguments\nfrom siirl.engine.actor import BasePPOActor\nfrom siirl.utils.model_utils.seqlen_balancing import prepare_dynamic_batch, restore_dynamic_batch\nif is_cuda_available:\n    from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input\nelif is_npu_available:\n    from transformers.integrations.npu_flash_attention import index_first_axis, pad_input, rearrange, unpad_input\n\n\n__all__ = [\"DataParallelPPOActor\"]\n\n\nclass DataParallelPPOActor(BasePPOActor):\n    def __init__(\n        self,\n        config: ActorArguments | RefArguments,\n        actor_module: nn.Module,\n        actor_optimizer: torch.optim.Optimizer = None,\n    ):\n        \"\"\"When optimizer is None, it is Reference Policy\"\"\"\n        super().__init__(config)\n        self.actor_module = actor_module\n        self.actor_optimizer = actor_optimizer\n        role = \"Ref\" if actor_optimizer is None else \"Actor\"\n\n        self.use_remove_padding = self.config.use_remove_padding\n        if torch.distributed.get_rank() == 0:\n            logger.info(f\"{role} use_remove_padding={self.use_remove_padding}\")\n        self.use_fused_kernels = self.config.use_fused_kernels\n        if torch.distributed.get_rank() == 0:\n            logger.info(f\"{role} use_fused_kernels={self.use_fused_kernels}\")\n        self.ulysses_sequence_parallel_size = self.config.ulysses_sequence_parallel_size\n        self.use_ulysses_sp = 
self.ulysses_sequence_parallel_size > 1\n\n        if self.config.entropy_from_logits_with_chunking:\n            entropy_from_logits = F.entropy_from_logits_with_chunking\n        else:\n            entropy_from_logits = F.entropy_from_logits\n\n        self.compute_entropy_from_logits = (\n            torch.compile(entropy_from_logits, dynamic=True)\n            if self.config.use_torch_compile  #  use torch compile by default\n            else entropy_from_logits\n        )\n        self.device_name = get_device_name()\n\n    def _forward_micro_batch(\n        self, micro_batch, temperature, calculate_entropy=False\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Returns:\n            entropy: # (bs, response_len)\n            log_probs: # (bs, response_len)\n        \"\"\"\n        response_length = micro_batch[\"responses\"].size(-1)\n        multi_modal_inputs = {}\n        if \"multi_modal_inputs\" in micro_batch.keys():\n            from siirl.utils.model_utils.model import extract_multi_modal_inputs\n            multi_modal_inputs = extract_multi_modal_inputs(micro_batch[\"multi_modal_inputs\"])\n        \n        with torch.autocast(device_type=self.device_name, dtype=torch.bfloat16):\n            input_ids = micro_batch[\"input_ids\"]\n            batch_size, seqlen = input_ids.shape\n            attention_mask = micro_batch[\"attention_mask\"]\n            position_ids = micro_batch[\"position_ids\"]\n            entropy = None\n            if position_ids.dim() == 3:  # qwen2vl mrope\n                position_ids = position_ids.transpose(0, 1)  # (bsz, 3, seqlen) -> (3, bsz, seqlen)\n\n            if self.use_remove_padding:\n                # input_ids_rmpad (total_nnz, ...)\n                input_ids_rmpad, indices, cu_seqlens, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask)\n                input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)\n\n                # unpad the position_ids to align the rotary\n                if position_ids.dim() == 3:\n                    position_ids_rmpad = index_first_axis(rearrange(position_ids, \"c b s ... -> (b s) c ...\"), indices).transpose(0, 1).unsqueeze(1)  # (3, bsz, seqlen) -> (3, 1, bsz * seqlen)\n                else:\n                    position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), \"b s ... 
-> (b s) ...\"), indices).transpose(0, 1)\n                if \"image_bound\" in multi_modal_inputs:\n                    from siirl.data_coordinator.dataloader.vision_utils import process_multi_modal_inputs_for_minicpmo\n                    multi_modal_inputs = process_multi_modal_inputs_for_minicpmo(\n                        input_ids, attention_mask, position_ids, cu_seqlens, multi_modal_inputs\n                    )\n\n                # for compute the log_prob\n                input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1)  # (1, total_nnz)\n\n                # pad and slice the inputs if sp > 1\n                if self.use_ulysses_sp:\n                    is_vlm_model = hasattr(\n                        getattr(self.actor_module, \"module\", self.actor_module).config, \"vision_config\"\n                    )\n                    if is_vlm_model:\n                        # vlm model's inputs will be sliced after embedding\n                        input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad(\n                            input_ids_rmpad,\n                            position_ids_rmpad=position_ids_rmpad,\n                            sp_size=self.ulysses_sequence_parallel_size,\n                        )\n                    else:\n                        input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(\n                            input_ids_rmpad,\n                            position_ids_rmpad=position_ids_rmpad,\n                            sp_size=self.ulysses_sequence_parallel_size,\n                        )\n                    input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(\n                        input_ids_rmpad_rolled,\n                        position_ids_rmpad=None,\n                        sp_size=self.ulysses_sequence_parallel_size,\n                    )\n\n                input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0)  # ((total_nnz / sp) + pad)\n\n                # only pass input_ids and position_ids to enable flash_attn_varlen\n                extra_args = {}\n                if self.use_fused_kernels:\n                    extra_args[\"temperature\"] = temperature\n                    extra_args[\"return_dict\"] = True\n\n                output = self.actor_module(\n                    input_ids=input_ids_rmpad,\n                    attention_mask=None,\n                    position_ids=position_ids_rmpad,\n                    **multi_modal_inputs,\n                    use_cache=False,\n                    **extra_args,\n                )  # prevent model thinks we are generating\n\n                if self.use_fused_kernels:\n                    log_probs = output.log_probs.squeeze(0)  # (total_nnz,)\n                    entropy_rmpad = output.entropy.squeeze(0)  # (total_nnz,)\n\n                else:\n                    logits_rmpad = output.logits.squeeze(0)  # (total_nnz, vocab_size)\n                    logits_rmpad.div_(temperature)\n\n                    # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen)\n                    inplace_backward = True\n                    if calculate_entropy:\n                        inplace_backward = False\n                    log_probs = logprobs_from_logits(\n                        logits=logits_rmpad,\n                        labels=input_ids_rmpad_rolled,\n                        inplace_backward=inplace_backward,\n                    )\n\n                    # compute entropy\n                    if 
calculate_entropy:\n                        if not self.config.entropy_checkpointing:\n                            entropy_rmpad = self.compute_entropy_from_logits(logits_rmpad)\n                        else:\n                            entropy_rmpad = torch.utils.checkpoint.checkpoint(\n                                self.compute_entropy_from_logits, logits_rmpad\n                            )\n\n                # gather log_prob if sp > 1\n                if self.use_ulysses_sp:\n                    # gather and unpad for the ulysses sp\n                    log_probs = gather_outpus_and_unpad(\n                        log_probs,\n                        gather_dim=0,\n                        unpad_dim=0,\n                        padding_size=pad_size,\n                    )\n                    if calculate_entropy:\n                        entropy_rmpad = gather_outpus_and_unpad(\n                            entropy_rmpad,\n                            gather_dim=0,\n                            unpad_dim=0,\n                            padding_size=pad_size,\n                        )\n                # pad back to (bsz, seqlen)\n                if calculate_entropy:\n                    full_entropy = pad_input(\n                        hidden_states=entropy_rmpad.unsqueeze(-1),\n                        indices=indices,\n                        batch=batch_size,\n                        seqlen=seqlen,\n                    )\n                full_log_probs = pad_input(\n                    hidden_states=log_probs.unsqueeze(-1),\n                    indices=indices,\n                    batch=batch_size,\n                    seqlen=seqlen,\n                )\n\n                # only return response part:\n                if calculate_entropy:\n                    entropy = full_entropy.squeeze(-1)[:, -response_length - 1 : -1]  # (bsz, response_length)\n                log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1 : -1]  # (bsz, response_length)\n\n            else:  # not using rmpad and no ulysses sp\n                extra_args = {}\n                if self.use_fused_kernels:\n                    extra_args[\"temperature\"] = temperature\n                    extra_args[\"return_dict\"] = True\n                output = self.actor_module(\n                    input_ids=input_ids,\n                    attention_mask=attention_mask,\n                    position_ids=position_ids,\n                    **multi_modal_inputs,\n                    use_cache=False,\n                    **extra_args,\n                )  # prevent model thinks we are generating\n\n                if self.use_fused_kernels:\n                    log_probs = output.log_probs[:, -response_length - 1 : -1]\n                    entropy = output.entropy[:, -response_length - 1 : -1]  # (bsz, response_length)\n\n                else:\n                    logits = output.logits\n\n                    logits.div_(temperature)\n                    logits = logits[:, -response_length - 1 : -1, :]  # (bsz, response_length, vocab_size)\n                    log_probs = logprobs_from_logits(logits, micro_batch[\"responses\"])\n                    if calculate_entropy:\n                        if not self.config.entropy_checkpointing:\n                            entropy = F.entropy_from_logits(logits)  # (bsz, response_length)\n                        else:\n                            entropy = torch.utils.checkpoint.checkpoint(F.entropy_from_logits, logits)\n\n            return entropy, log_probs\n\n    def 
_optimizer_step(self):\n        assert self.config.grad_clip is not None\n        if isinstance(self.actor_module, FSDP):\n            grad_norm = self.actor_module.clip_grad_norm_(max_norm=self.config.grad_clip)\n        elif isinstance(self.actor_module, FSDPModule):\n            grad_norm = fsdp2_clip_grad_norm_(self.actor_module.parameters(), max_norm=self.config.grad_clip)\n        else:\n            grad_norm = torch.nn.utils.clip_grad_norm_(self.actor_module.parameters(), max_norm=self.config.grad_clip)\n\n        if isinstance(grad_norm, DTensor):\n            grad_norm = grad_norm.full_tensor()\n\n        # if grad_norm is not finite, skip the update\n        if not torch.isfinite(grad_norm):\n            logger.warning(f\"WARN: rank {torch.distributed.get_rank()} grad_norm is not finite: {grad_norm}\")\n            self.actor_optimizer.zero_grad()\n        else:\n            self.actor_optimizer.step()\n        return grad_norm\n\n    @GPUMemoryLogger(role=\"dp actor\", logger=logger)\n    def compute_log_prob(self, data: TensorDict, calculate_entropy=False) -> torch.Tensor:\n        \"\"\"Compute the log probability of the responses given input_ids, attention_mask and position_ids\n\n        Args:\n            data (Tensordict): a Tensordict containing keys\n\n                ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the\n                concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``.\n\n                ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64.\n\n                ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64.\n\n                ``responses``:  tensor of shape [batch_size, response_length]. 
torch.int64.\n\n        Returns:\n            torch.Tensor: the log_prob tensor\n        \"\"\"\n        # set to eval\n        self.actor_module.eval()\n\n        micro_batch_size = data[\"micro_batch_size\"]\n        temperature = data[\"temperature\"]  # temperature must be in the data.meta_info to avoid silent error\n        use_dynamic_bsz = data[\"use_dynamic_bsz\"]\n        select_keys = [\"responses\", \"input_ids\", \"attention_mask\", \"position_ids\"]\n        batch = data.select(*select_keys)\n        has_multi_modal_inputs = \"multi_modal_inputs\" in data.keys()\n        multi_modal_inputs = {}\n        if has_multi_modal_inputs:\n            micro_batches = batch.split(micro_batch_size)\n            num_micro_batches = data.batch_size[0] // micro_batch_size\n            multi_modal_inputs = np.array_split(data[\"multi_modal_inputs\"], num_micro_batches, axis=0)\n            for i in range(num_micro_batches):\n                micro_batches[i][\"multi_modal_inputs\"] = multi_modal_inputs[i]\n        elif use_dynamic_bsz:\n            # split using dynamic bsz\n            max_token_len = data[\"max_token_len\"] * self.ulysses_sequence_parallel_size\n            micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len)\n        else:\n            micro_batches = batch.split(micro_batch_size)\n\n        log_probs_lst = []\n        entropy_lst = []\n        for micro_batch in micro_batches:\n            with torch.no_grad():\n                entropy, log_probs = self._forward_micro_batch(micro_batch, temperature=temperature, calculate_entropy=calculate_entropy)\n            log_probs_lst.append(log_probs)\n            if calculate_entropy:\n                entropy_lst.append(entropy)\n\n        log_probs = torch.concat(log_probs_lst, dim=0)\n        entropys = None\n        if calculate_entropy:\n            entropys = torch.concat(entropy_lst, dim=0)\n        if use_dynamic_bsz:\n            indices = list(itertools.chain.from_iterable(indices))\n            assert len(indices) == log_probs.size(0), f\"{len(indices)} vs. {log_probs.size()}\"\n            revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)\n            log_probs = log_probs[revert_indices]\n\n        return log_probs, entropys\n\n    @GPUMemoryLogger(role=\"dp actor\", logger=logger)\n    def update_policy(self, data: TensorDict):\n        # make sure we are in training mode\n        self.actor_module.train()\n\n        temperature = data[\"temperature\"]  # temperature must be in the data.meta_info to avoid silent error\n        select_keys = [\n            \"responses\",\n            \"response_mask\",\n            \"input_ids\",\n            \"attention_mask\",\n            \"position_ids\",\n            \"old_log_probs\",\n            \"advantages\",\n        ]\n        if self.config.use_kl_loss:\n            select_keys.append(\"ref_log_prob\")\n        batch = data.select(*select_keys)\n        has_multi_modal_inputs = \"multi_modal_inputs\" in data.keys()\n\n        # Split to make minibatch iterator for updating the actor\n        # See PPO paper for details. 
https://arxiv.org/abs/1707.06347\n        if has_multi_modal_inputs:\n            num_mini_batches = data.batch_size[0] // self.config.ppo_mini_batch_size\n            mini_batches = batch.split(self.config.ppo_mini_batch_size)\n            multi_modal_inputs = np.array_split(data[\"multi_modal_inputs\"], num_mini_batches, axis=0)\n            for i in range(num_mini_batches):\n                mini_batches[i][\"multi_modal_inputs\"] = multi_modal_inputs[i]\n            dataloader = mini_batches\n        else:\n            dataloader = batch.split(self.config.ppo_mini_batch_size)\n        on_policy = len(dataloader) == 1 and self.config.ppo_epochs == 1\n        metrics = {}\n        for epoch in range(self.config.ppo_epochs):\n            for batch_idx, data in enumerate(dataloader):\n                # split batch into micro_batches\n                mini_batch = data\n                if has_multi_modal_inputs:\n                    self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu           \n                    micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)\n                    num_micro_batches = mini_batch.batch_size[0] // self.config.ppo_micro_batch_size_per_gpu\n                    multi_modal_inputs = np.array_split(mini_batch[\"multi_modal_inputs\"], num_micro_batches, axis=0)\n                    for i in range(num_micro_batches):\n                        micro_batches[i][\"multi_modal_inputs\"] = multi_modal_inputs[i]\n                elif self.config.use_dynamic_bsz:\n                    max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size\n                    micro_batches, _ = prepare_dynamic_batch(mini_batch, max_token_len=max_token_len)\n                else:\n                    self.gradient_accumulation = (\n                        self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu\n                    )\n                    micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)\n\n                self.actor_optimizer.zero_grad()\n\n                for data in micro_batches:\n                    # Support all hardwares\n\n                    micro_batch_metrics = {}\n                    \n                    data = data.to(get_device_id())  # actor device is cpu when using offload\n                    response_mask = data[\"response_mask\"]\n                    old_log_prob = data[\"old_log_probs\"]\n                    advantages = data[\"advantages\"]\n\n                    entropy_coeff = self.config.entropy_coeff\n                    loss_agg_mode = self.config.loss_agg_mode\n                    use_cpgd_loss = self.config.use_cpgd_loss\n                    policy_drift_coeff = self.config.policy_drift_coeff\n\n                    if self.config.use_dynamic_bsz:\n                        loss_scale_factor = response_mask.shape[0] / self.config.ppo_mini_batch_size\n                    else:\n                        loss_scale_factor = 1 / self.gradient_accumulation\n\n                    # all return: (bsz, response_length)\n                    calculate_entropy = False\n                    if entropy_coeff != 0:\n                        calculate_entropy = True\n\n                    entropy, log_prob = self._forward_micro_batch(\n                        data, temperature=temperature, calculate_entropy=calculate_entropy\n                    )\n\n                    # for fully_async_policy recipe\n           
         if hasattr(self.config, \"use_rollout_log_probs\") and self.config.use_rollout_log_probs:\n                        old_log_prob = data[\"old_log_probs\"]\n                    else:\n                        if on_policy:\n                            old_log_prob = log_prob.detach()\n                        else:\n                            old_log_prob = data[\"old_log_probs\"]\n\n                    loss_mode = self.config.policy_loss.loss_mode\n\n                    # Extract pre-computed rollout importance sampling weights if present\n                    # Weights are computed centrally in trainer and added when algorithm.rollout_is=True\n                    rollout_is_weights = data.get(\"rollout_is_weights\", None)\n\n                    # NOTE: Both mismatch diagnostic metrics (PPL, KL, etc.) and IS weight metrics\n                    # are computed centrally in ray_trainer.py for consistency and efficiency.\n\n                    # gpg -> core_algos.compute_policy_loss_gpg\n                    # clip_cov -> core_algos.compute_policy_loss_clip_cov\n\n                    policy_loss_fn = get_policy_loss_fn(loss_mode)\n\n                    # Compute policy loss (all functions return 4 values)\n                    pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(\n                        old_log_prob=old_log_prob,\n                        log_prob=log_prob,\n                        advantages=advantages,\n                        response_mask=response_mask,\n                        loss_agg_mode=loss_agg_mode,\n                        config=self.config,\n                        rollout_is_weights=rollout_is_weights,\n                    )\n\n                    if entropy_coeff != 0:\n                        entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)\n\n                        # compute policy loss\n                        policy_loss = pg_loss - entropy_loss * entropy_coeff\n                    else:\n                        policy_loss = pg_loss\n\n                    if self.config.use_kl_loss:\n                        ref_log_prob = data[\"ref_log_prob\"]\n                        # compute kl loss\n                        kld = kl_penalty(\n                            logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type\n                        )\n                        kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)\n\n                        policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef\n                        micro_batch_metrics[\"actor/kl_loss\"] = kl_loss.detach().item() * loss_scale_factor\n                        micro_batch_metrics[\"actor/kl_coef\"] = self.config.kl_loss_coef\n\n                    if use_cpgd_loss and policy_drift_coeff != 0:\n                        # compute policy drift loss from CPGD\n                        policy_drift = ((log_prob.detach() - old_log_prob).exp() - 1.0).clamp(max=2.0) * log_prob\n                        policy_drift_loss = agg_loss(\n                            loss_mat=policy_drift, loss_mask=response_mask, loss_agg_mode=loss_agg_mode\n                        )\n                        policy_loss = policy_loss + policy_drift_loss * policy_drift_coeff\n\n                    if self.config.use_dynamic_bsz:\n                        # relative to the dynamic bsz\n                        loss = policy_loss * loss_scale_factor\n                    else:\n                       
 loss = policy_loss * loss_scale_factor\n                    loss.backward()\n\n                    micro_batch_metrics.update(\n                        {\n                            \"actor/pg_clip_mean\": (log_prob - old_log_prob).exp().mean().detach().item(),\n                            \"actor/pg_clip_min\": (log_prob - old_log_prob).exp().min().detach().item(),\n                            \"actor/pg_clip_max\": (log_prob - old_log_prob).exp().max().detach().item(),\n                            \"actor/pg_loss\": pg_loss.detach().item() * loss_scale_factor,\n                            \"actor/pg_clipfrac\": pg_clipfrac.detach().item(),\n                            \"actor/ppo_kl\": ppo_kl.detach().item(),\n                            \"actor/pg_clipfrac_lower\": pg_clipfrac_lower.detach().item(),\n                        }\n                    )\n                    append_to_dict(metrics, micro_batch_metrics)\n\n                grad_norm = self._optimizer_step()\n                mini_batch_metrics = {\"actor/grad_norm\": grad_norm.detach().item()}\n                append_to_dict(metrics, mini_batch_metrics)\n        self.actor_optimizer.zero_grad()\n        return metrics\n"
  },
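  {
    "path": "docs/examples/ppo_microbatch_accumulation_sketch.py",
    "content": "# Editor's note: illustrative sketch, NOT part of the original siiRL sources.\n# It mirrors, under simplified assumptions, the mini-batch / micro-batch split with\n# gradient accumulation used by the actor's update_policy above: each micro-batch loss\n# is scaled by 1 / gradient_accumulation (static batching) before backward(), so the\n# single optimizer step per mini-batch sees an averaged gradient.\n# All names below (toy_policy_loss, ppo_style_update) are hypothetical.\nimport torch\nfrom torch import nn\n\n\ndef toy_policy_loss(model: nn.Module, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n    # Stand-in for the PPO policy loss; a plain MSE keeps the sketch self-contained.\n    return ((model(x) - y) ** 2).mean()\n\n\ndef ppo_style_update(model, optimizer, xs, ys, mini_bsz=8, micro_bsz=2, grad_clip=1.0):\n    assert mini_bsz % micro_bsz == 0\n    grad_accum = mini_bsz // micro_bsz\n    for mb_x, mb_y in zip(xs.split(mini_bsz), ys.split(mini_bsz)):\n        optimizer.zero_grad()\n        for micro_x, micro_y in zip(mb_x.split(micro_bsz), mb_y.split(micro_bsz)):\n            # Same scaling rule as the static-batching branch of the real actor.\n            loss = toy_policy_loss(model, micro_x, micro_y) / grad_accum\n            loss.backward()\n        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)\n        optimizer.step()\n    optimizer.zero_grad()\n\n\nif __name__ == \"__main__\":\n    torch.manual_seed(0)\n    model = nn.Linear(4, 1)\n    opt = torch.optim.SGD(model.parameters(), lr=1e-2)\n    ppo_style_update(model, opt, torch.randn(32, 4), torch.randn(32, 1))\n"
  },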
  {
    "path": "siirl/engine/actor/embodied_actor.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nSingle Process Actor\n\"\"\"\n\nimport itertools\nfrom typing import Iterable, Tuple\n\nimport torch\nfrom torch import nn\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\nfrom loguru import logger\nfrom tensordict import TensorDict\n\nfrom siirl.dag_worker import core_algos\nfrom siirl.engine.actor import BasePPOActor\nfrom siirl.utils.extras.py_functional import append_to_dict\nfrom siirl.utils.model_utils.torch_functional import logprobs_from_logits, log_probs_from_logits_all_rmpad\nfrom siirl.utils.model_utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx\nimport siirl.utils.model_utils.torch_functional as siirl_F\nfrom codetiming import Timer\nfrom flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis\n\n__all__ = ['RobDataParallelPPOActor']\n\n\n\nclass RobDataParallelPPOActor(BasePPOActor):\n\n    def __init__(\n        self,\n        config,\n        actor_module: nn.Module,\n        actor_optimizer: torch.optim.Optimizer = None,\n    ):\n        \"\"\"When optimizer is None, it is Reference Policy\"\"\"\n        super().__init__(config)\n        self.actor_module = actor_module\n        self.actor_optimizer = actor_optimizer\n        self.use_remove_padding = self.config.use_remove_padding\n        self.ulysses_sequence_parallel_size = self.config.ulysses_sequence_parallel_size\n        self.use_ulysses_sp = False #self.ulysses_sequence_parallel_size > 1\n        self.compute_entropy_from_logits = torch.compile(siirl_F.entropy_from_logits, dynamic=True)\n       \n    def process_tensor(self, tensor, pad_id):\n        mask = tensor != pad_id\n        if not torch.all(mask == mask[0:1], dim=1).all():\n            raise ValueError(\"Padding error!\")\n        base_mask = mask[0]\n        valid_len = base_mask.sum().item()\n        return tensor[:, base_mask], valid_len\n    \n    def generate_traj_mask(self, end_step, traj_len):\n        \"\"\"\n        Args:\n            end_step: (batch_size,), \n            traj_len: \n        Returns:\n            mask: (batch_size, traj_len),\n        \"\"\"\n        steps = torch.arange(traj_len, device=end_step.device)  # (traj_len,)\n        steps_expanded = steps.unsqueeze(0).expand(end_step.size(0), -1)\n        mask = steps_expanded < end_step.unsqueeze(1)  # (batch_size, traj_len)\n        return mask\n    \n    def apply_mask_with_grad_control(self, log_probs, entropy, mask):\n        \"\"\"\n        Args:\n            log_probs: (batch_size, traj_len, ...)\n            entropy:   (batch_size, traj_len, ...)\n            mask:      (batch_size, traj_len)\n        Returns:\n            log_probs_masked: \n            entropy_masked:   \n        \"\"\"\n        mask_expanded = mask.unsqueeze(-1)  \n\n        log_probs_masked = torch.where(\n            mask_expanded,\n            log_probs,\n            torch.zeros_like(log_probs, 
requires_grad=False)  \n        )\n\n        entropy_masked = torch.where(\n            mask_expanded,\n            entropy,\n            torch.zeros_like(entropy, requires_grad=False)   \n        )\n\n        return log_probs_masked, entropy_masked\n\n    def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        micro_batch:\n        \n        Returns: \n            entropy: # (bs, response_len)\n            log_probs: # (bs, response_len)\n        \"\"\"\n        batch_size = micro_batch['responses'].size(0)\n        traj_len = micro_batch['responses'].size(1)\n        tot_pad_len = micro_batch['input_ids'].size(2)\n        \n        assert all(micro_batch[key].size(0) == batch_size for key in ['responses', 'input_ids', 'attention_mask', 'pixel_values'])\n        assert all(micro_batch[key].size(1) == traj_len for key in ['responses', 'input_ids', 'attention_mask', 'pixel_values'])\n        assert all(micro_batch[key].size(2) == tot_pad_len for key in [ 'input_ids', 'attention_mask'])\n        \n            \n        response_length = micro_batch['responses'].size(-1) # 7*8\n        \n        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):\n            input_ids = micro_batch['input_ids']\n            attention_mask = micro_batch['attention_mask']\n            pixel_values = micro_batch[\"pixel_values\"]\n            responses = micro_batch[\"responses\"]\n            \n            input_ids = input_ids.reshape((batch_size * traj_len,) + input_ids.shape[2:])\n            attention_mask = attention_mask.reshape((batch_size * traj_len,) + attention_mask.shape[2:])\n            pixel_values = pixel_values.reshape((batch_size * traj_len,) + pixel_values.shape[2:])\n            responses = responses.reshape((batch_size * traj_len,) + responses.shape[2:])\n            \n            input_ids_unpad, _ = self.process_tensor(input_ids, self.pad_token_id)\n            attention_mask_unpad, _ = self.process_tensor(attention_mask, 0)\n            \n            if self.config.embodied_type == \"openvla-oft\":\n                logits = self.actor_module(input_ids=input_ids_unpad,\n                                        attention_mask=attention_mask_unpad,\n                                        pixel_values=pixel_values,\n                                        )  # prevent model thinks we are generating\n                \n                assert self.actor_module.vocab_size == 32000\n                start_index = self.actor_module.vocab_size - 256 \n                logits = logits[..., -256-64:-64]  # Shape: [batch_size, seq_len, 256]\n                responses = responses - start_index\n                #assert (0<=responses<=255).all()\n            \n                logits = logits.div(temperature) \n                \n                log_probs = logprobs_from_logits(logits, responses)\n                entropy = siirl_F.entropy_from_logits(logits)  # (bsz, response_length)\n            \n                assert len(log_probs.shape)==2 and len(entropy.shape)==2 \n                log_probs = log_probs.reshape((batch_size, traj_len*8,7) )\n                entropy = entropy.reshape((batch_size, traj_len*8,7) )\n\n                mask = self.generate_traj_mask(micro_batch['finish_step'], traj_len*8)\n                log_probs, entropy = self.apply_mask_with_grad_control(log_probs, entropy, mask)\n                \n                log_probs = log_probs.reshape((batch_size, traj_len*response_length))\n                entropy = 
entropy.reshape((batch_size, traj_len*response_length))\n                \n            elif self.config.embodied_type == \"openvla\":\n                output = self.actor_module(input_ids=input_ids_unpad,\n                                    attention_mask=attention_mask_unpad,\n                                    pixel_values=pixel_values,\n                                    use_cache=False)  # prevent model thinks we are generating\n                logits = output.logits\n                \n                logits = logits[:, -response_length - 1:-1]  # (bsz, response_length)\n                logits = logits.div(temperature) \n                \n                log_probs = logprobs_from_logits(logits, responses)\n                entropy = siirl_F.entropy_from_logits(logits)  # (bsz, response_length)\n                #ADD\n                \n                log_probs = log_probs.reshape((batch_size, traj_len,) + log_probs.shape[1:])\n                entropy = entropy.reshape((batch_size, traj_len,) + entropy.shape[1:])\n\n                \n                mask = self.generate_traj_mask(micro_batch['finish_step'], traj_len)\n                log_probs, entropy = self.apply_mask_with_grad_control(log_probs, entropy, mask)\n                \n                log_probs = log_probs.reshape((batch_size, traj_len*response_length))\n                entropy = entropy.reshape((batch_size, traj_len*response_length))\n                \n                \n\n            return entropy, log_probs\n    \n    def _forward_micro_batch_update(self, input_ids, attention_mask, pixel_values, responses, temperature) -> Tuple[torch.Tensor, torch.Tensor]:\n       \n        \n        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):\n            if self.config.embodied_type == \"openvla-oft\":\n                \n                input_ids_unpad, _ = self.process_tensor(input_ids, self.pad_token_id)\n                attention_mask_unpad, _ = self.process_tensor(attention_mask, 0)\n\n                \n                logits = self.actor_module(input_ids=input_ids_unpad,\n                                                attention_mask=attention_mask_unpad,\n                                                pixel_values=pixel_values,\n                                                )  \n                \n                assert logits.requires_grad \n                \n                assert self.actor_module.vocab_size == 32000\n                start_index = self.actor_module.vocab_size - 256 \n                logits = logits[..., -256-64:-64]  # Shape: [batch_size, seq_len, 256]\n                responses = responses - start_index\n                \n                logits = logits.div(temperature) \n                \n                log_probs = logprobs_from_logits(logits, responses)\n                entropy = siirl_F.entropy_from_logits(logits)  # (bsz, response_length)\n                \n                log_probs = log_probs.reshape((1, -1))\n                entropy = entropy.reshape((1, -1))\n                \n                return entropy, log_probs\n            \n            elif self.config.embodied_type == \"openvla\":\n                response_length = responses.size(-1)\n                input_ids_unpad, _ = self.process_tensor(input_ids, self.pad_token_id)\n                attention_mask_unpad, _ = self.process_tensor(attention_mask, 0)\n                output = self.actor_module(input_ids=input_ids_unpad,\n                                        attention_mask=attention_mask_unpad,\n                      
                  pixel_values=pixel_values,\n                                        use_cache=False)  # prevent model thinks we are generating\n                logits = output.logits\n                #\n                \n                logits = logits[:, -response_length - 1:-1]  # (bsz, response_length)\n                logits = logits.div(temperature) \n                \n                log_probs = logprobs_from_logits(logits, responses)\n                entropy = siirl_F.entropy_from_logits(logits)  # (bsz, response_length)\n                \n                \n                log_probs = log_probs.reshape((1, -1))\n                entropy = entropy.reshape((1, -1))\n\n                return entropy, log_probs\n                \n\n    def _forward_micro_batch_entropy(self, micro_batch, temperature) -> Tuple[torch.Tensor, torch.Tensor]:\n        batch_size = micro_batch['responses'].size(0)\n        traj_len = micro_batch['responses'].size(1)\n        tot_pad_len = micro_batch['input_ids'].size(2)\n \n        assert all(micro_batch[key].size(0) == batch_size for key in ['responses', 'input_ids', 'attention_mask', 'pixel_values'])\n        assert all(micro_batch[key].size(1) == traj_len for key in ['responses', 'input_ids', 'attention_mask', 'pixel_values'])\n        assert all(micro_batch[key].size(2) == tot_pad_len for key in [ 'input_ids', 'attention_mask'])\n            \n        response_length = micro_batch['responses'].size(-1)\n        #assert response_length == 7*8\n        \n        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):\n            input_ids = micro_batch['input_ids']\n            #batch_size, seqlen = input_ids.shape\n            attention_mask = micro_batch['attention_mask']\n            pixel_values = micro_batch[\"pixel_values\"]\n            \n            input_ids = input_ids.reshape((batch_size * traj_len,) + input_ids.shape[2:])\n            attention_mask = attention_mask.reshape((batch_size * traj_len,) + attention_mask.shape[2:])\n            pixel_values = pixel_values.reshape((batch_size * traj_len,) + pixel_values.shape[2:])\n            \n            \n            input_ids_unpad, _ = self.process_tensor(input_ids, self.pad_token_id)\n            attention_mask_unpad, _ = self.process_tensor(attention_mask, 0)\n\n            if  self.config.embodied_type == \"openvla-oft\":\n            \n                logits = self.actor_module(input_ids=input_ids_unpad,\n                                                attention_mask=attention_mask_unpad,\n                                                pixel_values=pixel_values,\n                                                ) \n            \n                assert self.actor_module.vocab_size == 32000\n                start_index = self.actor_module.vocab_size - 256 \n                logits = logits[..., -256-64:-64]  # Shape: [batch_size, seq_len, 256]\n            \n                logits = logits.div(temperature) \n            \n                entropy = siirl_F.entropy_from_logits(logits)  # (bsz, response_length)\n\n                assert len(entropy.shape)==2 \n                entropy = entropy.reshape((batch_size, traj_len*8,7) )\n                mask = self.generate_traj_mask(micro_batch['finish_step'], traj_len*8)\n                _, entropy = self.apply_mask_with_grad_control(entropy, entropy, mask)\n                entropy = entropy.reshape((batch_size, traj_len*response_length))\n                return entropy\n            \n            elif self.config.embodied_type == \"openvla\":\n 
               output = self.actor_module(input_ids=input_ids_unpad,\n                                        attention_mask=attention_mask_unpad,\n                                        pixel_values=pixel_values,\n                                        use_cache=False)  # prevent model thinks we are generating\n                logits = output.logits\n                #\n                \n                \n                logits = logits[:, -response_length - 1:-1]  # (bsz, response_length)\n                logits = logits.div(temperature) \n                \n                entropy = siirl_F.entropy_from_logits(logits)  # (bsz, response_length)\n                #ADD\n\n                entropy = entropy.reshape((batch_size, traj_len,) + entropy.shape[1:])\n                mask = self.generate_traj_mask(micro_batch['finish_step'], traj_len)\n                _, entropy = self.apply_mask_with_grad_control(entropy, entropy, mask)\n                entropy = entropy.reshape((batch_size, traj_len*response_length))\n                return entropy\n\n\n    def _optimizer_step(self):\n        assert self.config.grad_clip is not None\n\n        if isinstance(self.actor_module, FSDP):\n            grad_norm = self.actor_module.clip_grad_norm_(max_norm=self.config.grad_clip)\n        else:\n            grad_norm = torch.nn.utils.clip_grad_norm_(self.actor_module.parameters(), max_norm=self.config.grad_clip)\n        self.actor_optimizer.step()\n        return grad_norm\n\n    def compute_log_prob(self, data: TensorDict, calculate_entropy=False) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"Compute the log probability of the responses given input_ids, attention_mask and position_ids\n\n        Args:\n            data (TensorDict): a TensorDict containing keys\n\n                ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the\n                concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``.\n\n                ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64.\n\n                ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64.\n\n                ``responses``:  tensor of shape [batch_size, response_length]. 
torch.int64.\n\n        Returns:\n            torch.Tensor: the log_prob tensor\n            torch.Tensor: the entropy tensor\n        \"\"\"\n        self.actor_module.eval()\n\n        micro_batch_size = data['micro_batch_size'] #256\n        temperature = data['temperature']  # temperature must be in the data.meta_info to avoid silent error # 1\n        use_dynamic_bsz = data['use_dynamic_bsz'] # True\n        self.pad_token_id = data['pad_token_id']\n        \n        # Note: finish_step is 1D and only needed for reward computation, not for log_prob\n        select_keys = ['responses', 'input_ids', 'attention_mask', 'pixel_values','finish_step']\n        batch = data.select(*select_keys)\n\n        if use_dynamic_bsz:\n            # split using dynamic bsz\n            max_token_len = data['max_token_len'] * self.ulysses_sequence_parallel_size\n            micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len)\n        else:\n            micro_batches = batch.split(micro_batch_size)\n\n        log_probs_lst = []\n        for i, micro_batch in enumerate(micro_batches):\n            with torch.no_grad():\n                _, log_probs = self._forward_micro_batch(micro_batch, temperature=temperature)\n            log_probs_lst.append(log_probs)\n        log_probs = torch.concat(log_probs_lst, dim=0)\n\n        if use_dynamic_bsz:\n            indices = list(itertools.chain.from_iterable(indices))\n            assert len(indices) == log_probs.size(0), f\"{len(indices)} vs. {log_probs.size()}\"\n            revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)\n            log_probs = log_probs[revert_indices]\n\n        return log_probs, None # TODO: implement entropy computation\n\n    def update_policy(self, data: TensorDict):\n        self.actor_module.train()\n\n        assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0\n        self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu\n        temperature = data['temperature']  # temperature must be in the data.meta_info to avoid silent error\n\n        select_keys = ['responses', 'input_ids', 'attention_mask', 'pixel_values', 'old_log_probs', 'advantages',\"finish_step\"]\n        batch = data.select(*select_keys)\n\n\n        # Split to make minibatch iterator for updating the actor\n        # See PPO paper for details. 
https://arxiv.org/abs/1707.06347\n        dataloader = batch.split(self.config.ppo_mini_batch_size)\n        metrics = {}\n        for batch_idx, data in enumerate(dataloader):\n            # split batch into micro_batches\n            mini_batch = data\n            if self.config.use_dynamic_bsz:\n                max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size\n                micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)\n            else:\n                # split batch into micro_batches\n                micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)\n\n            self.actor_optimizer.zero_grad()\n\n            for test_idx, data in enumerate(micro_batches):\n                data = data.cuda()  # actor device is cpu when using offload\n                responses = data['responses']\n                \n                response_length = responses.size(1) *  responses.size(2)\n                finish_step = data['finish_step'] * self.config.action_token_len\n                steps = torch.arange(response_length, device=data['responses'].device)  # (traj_len,)\n                steps_expanded = steps.unsqueeze(0).expand(data['responses'].size(0), -1)\n                response_mask = steps_expanded < finish_step.unsqueeze(1)  # (batch_size, traj_len)\n                \n                response_mask_sum = response_mask.sum(axis=None)\n\n                old_log_prob = data['old_log_probs']\n                advantages = data['advantages']\n                \n                #clip_ratio = self.config.clip_ratio\n                clip_ratio_high = self.config.clip_ratio_high\n                clip_ratio_low = self.config.clip_ratio_low\n                entropy_coeff = self.config.entropy_coeff\n\n                batch_size = data['responses'].size(0)\n                traj_len = data['responses'].size(1)\n                tot_pad_len = data['input_ids'].size(2)\n                \n                \n                input_ids = data['input_ids']\n                attention_mask = data['attention_mask']\n                pixel_values = data[\"pixel_values\"]\n                responses = data[\"responses\"]\n                \n                \n                input_ids = input_ids.reshape((batch_size * traj_len,) + input_ids.shape[2:])\n                attention_mask = attention_mask.reshape((batch_size * traj_len,) + attention_mask.shape[2:])\n                pixel_values = pixel_values.reshape((batch_size * traj_len,) + pixel_values.shape[2:])\n                responses = responses.reshape((batch_size * traj_len,) + responses.shape[2:])\n                \n                loss_info = {\n                    #'actor/entropy_loss': entropy_loss.detach().item(),\n                    'actor/pg_loss':0,\n                    'actor/pg_clipfrac': 0,\n                    'actor/ppo_kl': 0,\n                }\n                \n                assert traj_len % self.config.traj_mini_batch_size ==0\n                traj_split_num = int(traj_len/self.config.traj_mini_batch_size)\n                \n                \n    \n\n                for i in range(0, traj_len, int(traj_len/traj_split_num)):\n                    entropy, log_prob = self._forward_micro_batch_update(input_ids=input_ids[i:i+int(traj_len/traj_split_num)], \n                                                                         attention_mask=attention_mask[i:i+int(traj_len/traj_split_num)], \n                                                
                         pixel_values=pixel_values[i:i+int(traj_len/traj_split_num)], \n                                                                         responses=responses[i:i+int(traj_len/traj_split_num)], \n                                                                         temperature=temperature)\n\n                    slice_id = i*self.config.action_token_len*self.config.action_chunks_len\n                    next_slice_id = (i+int(traj_len/traj_split_num))*self.config.action_token_len*self.config.action_chunks_len\n                    old_log_prob_tmp = old_log_prob[:, slice_id: next_slice_id]\n                    advantages_tmp = advantages[:, slice_id: next_slice_id]\n                    response_mask_tmp = response_mask[:, slice_id: next_slice_id]\n                        \n                    pg_loss, pg_clipfrac, ppo_kl, _ = core_algos.compute_policy_loss_vanilla(old_log_prob=old_log_prob_tmp,\n                                                                            log_prob=log_prob,\n                                                                            advantages=advantages_tmp,\n                                                                            response_mask=response_mask_tmp,\n                                                                            config=self.config)\n                    \n                    response_mask_tmp_sum = response_mask_tmp.sum(axis=None)\n                    pg_loss = pg_loss* response_mask_tmp_sum\n                    pg_clipfrac = pg_clipfrac* response_mask_tmp_sum / response_mask_sum\n                    ppo_kl = ppo_kl* response_mask_tmp_sum / response_mask_sum\n                    \n                    policy_loss = pg_loss / response_mask_sum\n                    \n                    loss = policy_loss / self.gradient_accumulation\n                    \n                    loss.backward()\n                    \n                    loss_info['actor/pg_loss'] =  loss_info['actor/pg_loss'] + policy_loss.detach().item()\n                    loss_info['actor/pg_clipfrac'] = loss_info['actor/pg_clipfrac'] + pg_clipfrac.detach().item()\n                    loss_info['actor/ppo_kl'] = loss_info['actor/ppo_kl'] +  ppo_kl.detach().item()\n\n                append_to_dict(metrics, loss_info)\n               \n            grad_norm = self._optimizer_step()\n            data = {'actor/grad_norm': grad_norm.detach().item()}\n            append_to_dict(metrics, data)\n            torch.cuda.empty_cache()\n        self.actor_optimizer.zero_grad()\n        torch.cuda.synchronize()\n        torch.distributed.barrier()\n        torch.cuda.empty_cache()\n\n        return metrics\n\n    \n    def compute_entropy(self, bacth_data: TensorDict):\n        \n        if bacth_data['train_mode'] ==True:\n            self.actor_module.train()\n        else:\n            self.actor_module.eval()\n\n        assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0\n        self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu\n        temperature = bacth_data['temperature']  # temperature must be in the data.meta_info to avoid slient error\n\n        select_keys = ['responses', 'input_ids', 'attention_mask', 'pixel_values', \"finish_step\"]\n        batch = bacth_data.select(*select_keys)\n\n        # Split to make minibatch iterator for updating the actor\n        # See PPO paper for details. 
https://arxiv.org/abs/1707.06347\n        dataloader = batch.split(self.config.ppo_mini_batch_size)\n        \n        metrics = {}\n        for batch_idx, data in enumerate(dataloader):\n            # split batch into micro_batches\n            mini_batch = data\n            if self.config.use_dynamic_bsz:\n                max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size\n                micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)\n            else:\n                # split batch into micro_batches\n                micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)\n\n            for data in micro_batches:\n                data = data.cuda()  # actor device is cpu when using offload\n                responses = data['responses']\n                response_length = responses.size(1) *  responses.size(2)\n                finish_step = data['finish_step'] * self.config.action_token_len\n                steps = torch.arange(response_length, device=data['responses'].device)  # (traj_len,)\n                steps_expanded = steps.unsqueeze(0).expand(data['responses'].size(0), -1)\n                response_mask = steps_expanded < finish_step.unsqueeze(1)  # (batch_size, traj_len)\n                \n\n                with torch.no_grad():\n                    entropy = self._forward_micro_batch_entropy(micro_batch=data, temperature=temperature)\n                    entropy_loss = siirl_F.masked_mean(entropy, response_mask)\n\n                if bacth_data['is_filtered'] and bacth_data['train_mode']:\n                    data = {\n                        'actor_after/entropy_loss_train': entropy_loss.detach().item(),\n                    }\n                    append_to_dict(metrics, data)\n                elif bacth_data['is_filtered'] and not bacth_data['train_mode']:\n                    data = {\n                        'actor_after/entropy_loss_eval': entropy_loss.detach().item(),\n                    }\n                    append_to_dict(metrics, data)\n                elif not bacth_data['is_filtered'] and bacth_data['train_mode']:\n                    data = {\n                        'actor_before/entropy_loss_train': entropy_loss.detach().item(),\n                    }\n                    append_to_dict(metrics, data)\n                elif not bacth_data['is_filtered'] and not bacth_data['train_mode']:\n                    data = {\n                        'actor_before/entropy_loss_eval': entropy_loss.detach().item(),\n                    }\n                    append_to_dict(metrics, data)\n                        \n                \n        torch.cuda.synchronize()\n        torch.distributed.barrier()\n        torch.cuda.empty_cache()\n        return metrics"
  },
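  {
    "path": "docs/examples/trajectory_mask_sketch.py",
    "content": "# Editor's note: illustrative sketch, NOT part of the original siiRL sources.\n# It isolates the trajectory-mask logic used by RobDataParallelPPOActor\n# (generate_traj_mask / apply_mask_with_grad_control): positions at or beyond a\n# rollout's finish_step must not contribute to log-prob, entropy, or loss statistics.\n# The helper names below (traj_mask, masked_mean) are hypothetical.\nimport torch\n\n\ndef traj_mask(end_step: torch.Tensor, traj_len: int) -> torch.Tensor:\n    # (batch,) -> bool mask of shape (batch, traj_len); True where step < end_step.\n    steps = torch.arange(traj_len, device=end_step.device)\n    return steps.unsqueeze(0) < end_step.unsqueeze(1)\n\n\ndef masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:\n    # Mean over valid (pre-finish_step) positions only.\n    return (values * mask).sum() / mask.sum().clamp(min=1)\n\n\nif __name__ == \"__main__\":\n    finish_step = torch.tensor([2, 5])  # two trajectories ending at different steps\n    mask = traj_mask(finish_step, traj_len=6)  # shape (2, 6)\n    log_probs = torch.randn(2, 6)\n    print(mask.int())\n    print(masked_mean(log_probs, mask))\n"
  },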
  {
    "path": "siirl/engine/actor/megatron_actor.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nMegatron Actor.\nIn megatron actor, the differences are:\n1. We only make minibatch\n\nNote that our model doesn't have to be `MegatronModule` because we don't share embedding in the last layer\n\"\"\"\n\nimport itertools\nfrom functools import partial\nfrom typing import Iterable\n\nimport torch\nimport torch.distributed\nimport numpy as np\nfrom loguru import logger\nfrom megatron.core import parallel_state as mpu\n\n# from megatron.core.optimizer import DistributedOptimizer\nfrom megatron.core.optimizer import DistributedOptimizer\nfrom megatron.core.pipeline_parallel import get_forward_backward_func\nfrom torch import nn\nfrom tensordict import TensorDict\n\nfrom siirl.utils.debug import GPUMemoryLogger\nfrom siirl.utils.debug.profile import Profiler\nfrom siirl.utils.extras.device import get_device_id, get_torch_device\nfrom siirl.utils.extras.py_functional import append_to_dict\nfrom siirl.utils.megatron.megatron_utils import get_model_config\nfrom siirl.utils.megatron.pipeline_parallel import make_batch_generator\nfrom siirl.utils.megatron.tensor_parallel import vocab_parallel_entropy, vocab_parallel_log_probs_from_logits\nfrom siirl.utils.model_utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches\nfrom siirl.utils.model_utils.torch_functional import broadcast_dict_tensor\nfrom siirl.engine.actor import BasePPOActor\nfrom siirl.dag_worker.core_algos import agg_loss, get_policy_loss_fn, kl_penalty\n\n__all__ = [\"MegatronPPOActor\"]\n\n\nclass MegatronPPOActor(BasePPOActor):\n    def __init__(\n        self,\n        config,\n        model_config,\n        hf_config,\n        tf_config,\n        actor_module: nn.ModuleList,\n        actor_optimizer: DistributedOptimizer,\n    ):\n        \"\"\"MeagtronPPOActor class. This class implements the simple PPO logics when the model is built with Megatron.\n\n        Args:\n            config (OmegaConf): the basic config that contains the hyper-parameters of PPO Actor. It must contain\n\n                ``ppo_micro_batch_size_per_gpu``: micro batch size when updating ppo.\n\n                ``ppo_mini_batch_size``: minibatch size when updating ppo using the batch data.\n\n                ``ppo_epochs``: number of epochs to update the actor using the batch data.\n\n                ``shuffle``: whether to shuffle the data after each ppo epoch.\n\n                ``clip_ratio``: clip ratio of the ppo algorithm. See https://arxiv.org/abs/1707.06347.\n\n                ``entropy_coeff``: entropy coefficient of the PPO loss. See https://arxiv.org/abs/1707.06347.\n            model_config (OmegaConf): model configuration. 
It must contains ``model_config.vocab_size`` and\n                ``model_config.hidden_size``\n            hf_config (PretrainedConfig): huggingface config\n            tf_config (TransformerConfig): mcore transformer config\n            actor_module (nn.ModuleList): actor module is a ModuleList that contains a list of nn.Module in this\n                pp stage.\n                each nn.Module in this rank holds a vpp module chunk. See https://arxiv.org/pdf/2104.04473.pdf for\n                more details.\n                The actor module has some constraints to follow in order to use the updating logics implemented here\n\n                1. It must implement unpad_input before any computation and pad_input after all the computation.\n                Remove padding is an\n                optimization that removes the padding tokens. See unpad_input and pad_input function in flash-attn\n                (https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py).\n\n                2. Each pp stage must return the hidden state with the same shape [total_nnz, 1, hidden_size],\n                where total_nnz is the number of valid tokens in this batch. If sequence parallel is enabled, the size\n                of the hidden state is [total_nnz // tp, 1, hidden_size].\n            actor_optimizer (DistributedOptimizer): currently, we only support DistributedOptimizer in Megatron.\n                It implements\n                zero1 optimizer that shards the optimizer state across dp ranks.\n\n        >>> from megatron.training import get_model\n        >>> from megatron.optimizer import get_megatron_optimizer\n        >>> actor_module = get_model(megatron_actor_model_provider, wrap_with_ddp=True)\n        >>> actor_module = nn.ModuleList(actor_module)\n        >>> actor_optimizer = get_megatron_optimizer(actor_module)\n        >>> actor = MegatronPPOActor(config=config,\n        >>>                          model_config=actor_model_config,\n        >>>                          hf_config=hf_config,\n        >>>                          tf_config=tf_config,\n        >>>                          actor_module=actor_module,\n        >>>                          actor_optimizer=actor_optimizer)\n        \"\"\"\n        super().__init__(config)\n        self._validate_config(config)\n        self.model_config = model_config\n        self.hf_config = hf_config\n        self.tf_config = tf_config\n        self.actor_module = actor_module\n        self.actor_optimizer: DistributedOptimizer = actor_optimizer\n        self.use_torch_profiler = self.config.profile.get(\"tool\") == \"torch\"\n        if self.use_torch_profiler:\n            self.prof = Profiler(\n                self.config.profile, tool_config=self.config.profile.get(\"tool_config\", {}).get(\"torch\", {})\n            )\n        else:\n            self.prof = None\n        self.use_fused_kernels = self.config.use_fused_kernels\n        if self.use_fused_kernels:\n            from siirl.models.mcore.model_forward_fused import patch_fused_forward\n\n            for model in self.actor_module:\n                patch_fused_forward(model)\n\n        config = get_model_config(self.actor_module[0])\n        if torch.distributed.get_rank() == 0:\n            print(config)\n\n    def _validate_config(self, config) -> None:\n        \"\"\"Validate config options not implemented for Megatron backend\"\"\"\n        assert config.ulysses_sequence_parallel_size == 1\n        if config.shuffle:\n            assert 
config.data_loader_seed is not None, \"If shuffle dataloader, seed must be manually set\"\n        if config.megatron.tensor_model_parallel_size == 1:\n            print(\"[Warning] Because actor tp size == 1, set sp to False\")\n            config.megatron.sequence_parallel = False\n        self.config = config\n\n    @GPUMemoryLogger(role=\"megatron actor\", logger=logger)\n    def compute_log_prob(self, data: TensorDict, calculate_entropy=False) -> torch.Tensor:\n        \"\"\"Compute the log probability of the responses given input_ids, attention_mask and position_ids\n\n        Args:\n            data (TensorDict): a TensorDict containing keys\n\n                ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the\n                concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``.\n\n                ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64.\n\n                ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64.\n\n                ``responses``:  tensor of shape [batch_size, response_length]. torch.int64.\n\n        Returns:\n            TensorDict: torch.Tensor: the log_prob tensor\n        \"\"\"\n        use_dynamic_bsz = data[\"use_dynamic_bsz\"]\n        micro_batch_size = data[\"micro_batch_size\"]\n        max_token_len = data[\"max_token_len\"]\n        if use_dynamic_bsz:\n            assert max_token_len is not None, \"max_token_len must be set when use_dynamic_bsz is True\"\n            max_token_len = max_token_len * self.config.megatron.context_parallel_size\n        else:\n            assert micro_batch_size is not None, (\n                \"micro batch size is needed for forward compute when use_dynamic_bsz is False\"\n            )\n\n        # We make recompute_old_log_prob by default here.\n        # TODO (zhangchi.usc1992): actually, this function should only return log_prob and this logic should be\n        # handled by user outside\n        entropys = torch.Tensor()\n\n        select_keys = [\"responses\", \"input_ids\", \"attention_mask\", \"position_ids\"]\n        batch = data.select(*select_keys)\n        input_ids = data[\"input_ids\"]\n        batch_size = input_ids.size(0)\n        response = batch[\"responses\"]\n        temperature = data[\"temperature\"]\n        response_length = response.size(1)\n        with torch.no_grad():\n            output = self.forward_backward_batch(\n                batch,\n                temperature=temperature,\n                forward_only=True,\n                calculate_entropy=calculate_entropy,\n                use_dynamic_bsz=use_dynamic_bsz,\n                micro_batch_size=micro_batch_size,\n                max_token_len=max_token_len,\n            )\n            if mpu.is_pipeline_last_stage(ignore_virtual=True):\n                # only on last rank. 
It should be on every tp rank\n                log_probs = [o[\"log_probs\"] for o in output[\"output\"]]  # (bs, seq_size)\n                log_probs = torch.cat(log_probs, dim=0).to(torch.float32)\n\n                if calculate_entropy:\n                    entropys = torch.cat([o[\"entropy\"] for o in output[\"output\"]], dim=0)\n                    entropys = entropys.to(torch.float32)\n\n                if use_dynamic_bsz:\n                    indices = output[\"indices\"]\n                    indices = list(itertools.chain.from_iterable(indices))\n                    assert len(indices) == log_probs.size(0), f\"{len(indices)} vs. {log_probs.size()}\"\n                    revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)\n                    log_probs = log_probs[revert_indices]\n                    if calculate_entropy:\n                        assert len(indices) == entropys.size(0), f\"{len(indices)} vs. {entropys.size()}\"\n                        entropys = entropys[revert_indices]\n            else:\n                # other pp ranks\n                log_probs = torch.empty(\n                    size=(batch_size, response_length), dtype=torch.float32, device=input_ids.device\n                )\n                if calculate_entropy:\n                    entropys = torch.empty(\n                        size=(batch_size, response_length), dtype=torch.float32, device=input_ids.device\n                    )\n\n            log_probs = log_probs.to(get_device_id())\n            # broadcast across pp ranks\n            torch.distributed.broadcast(\n                tensor=log_probs,\n                src=mpu.get_pipeline_model_parallel_last_rank(),\n                group=mpu.get_pipeline_model_parallel_group(),\n                async_op=False,\n            )\n            log_probs = log_probs.to(\"cpu\")\n\n            if calculate_entropy:\n                entropys = entropys.to(get_device_id())\n                torch.distributed.broadcast(\n                    tensor=entropys,\n                    src=mpu.get_pipeline_model_parallel_last_rank(),\n                    group=mpu.get_pipeline_model_parallel_group(),\n                    async_op=False,\n                )\n                entropys = entropys.to(\"cpu\")\n\n        # add empty cache after each compute\n        get_torch_device().empty_cache()\n\n        return log_probs, entropys\n\n    def compute_ppo_loss(self, model_output, data):\n        log_prob = model_output[\"log_probs\"]\n        entropy = model_output.get(\"entropy\", None)\n\n        metrics = {}\n\n        response_mask = data[\"response_mask\"].to(bool)\n        # compute policy loss\n        old_log_prob = data[\"old_log_probs\"]\n        advantages = data[\"advantages\"]\n\n        loss_agg_mode = self.config.loss_agg_mode\n\n        loss_mode = self.config.policy_loss.loss_mode\n\n        policy_loss_fn = get_policy_loss_fn(loss_mode)\n        pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(\n            old_log_prob=old_log_prob,\n            log_prob=log_prob,\n            advantages=advantages,\n            response_mask=response_mask,\n            loss_agg_mode=loss_agg_mode,\n            config=self.config,\n        )\n\n        metrics.update(\n            {\n                \"actor/pg_loss\": pg_loss.detach().item(),\n                \"actor/pg_clipfrac\": pg_clipfrac.detach().item(),\n                \"actor/ppo_kl\": ppo_kl.detach().item(),\n                \"actor/pg_clipfrac_lower\": 
pg_clipfrac_lower.detach().item(),\n            }\n        )\n        policy_loss = pg_loss\n\n        # add entropy loss\n        if entropy is not None:\n            entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)\n            entropy_coeff = self.config.entropy_coeff\n            policy_loss -= entropy_coeff * entropy_loss\n\n        # add kl loss\n        if self.config.use_kl_loss:\n            ref_log_prob = data[\"ref_log_prob\"]\n            # compute kl loss\n            kld = kl_penalty(logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type)\n            kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=self.config.loss_agg_mode)\n\n            policy_loss += kl_loss * self.config.kl_loss_coef\n            metrics[\"actor/kl_loss\"] = kl_loss.detach().item()\n            metrics[\"actor/kl_coef\"] = self.config.kl_loss_coef\n\n        return policy_loss, metrics\n\n    def forward_backward_batch(\n        self,\n        data: TensorDict,\n        temperature: float,  \n        forward_only=False,\n        calculate_entropy=False,\n        use_dynamic_bsz=False,\n        micro_batch_size=None,\n        max_token_len=None,\n    ):\n        \"\"\"\n        We assume:\n        - The model takes input: (input_ids, attention_mask, position_ids). No rmpad for the input\n        - The communication shape is (total_nnz_pad_to_sp // tp_size, 1, hidden_size) if sequence parallel is enabled\n        \"\"\"\n        # broadcast from last pp rank to all other pp ranks\n        # TODO: actually, we just need to control the sampling order.\n        data.to(get_device_id())\n        data = data.contiguous()\n        mini_batch = data\n        broadcast_dict_tensor(\n            mini_batch,\n            src=mpu.get_pipeline_model_parallel_last_rank(),\n            group=mpu.get_pipeline_model_parallel_group(),\n        )\n        \n        # broadcast from cp rank 0 to all other cp ranks to ensure same data across CP group\n        cp_size = mpu.get_context_parallel_world_size()\n        if cp_size > 1:\n            # Get the first rank in this CP group (cp_rank=0)\n            cp_group = mpu.get_context_parallel_group()\n            cp_group_ranks = torch.distributed.get_process_group_ranks(cp_group)\n            src_rank = cp_group_ranks[0]  # cp_rank=0 in this group\n            broadcast_dict_tensor(\n                mini_batch,\n                src=src_rank,\n                group=cp_group,\n            )\n        mini_batch.to(\"cpu\")\n        # split into micro-batches\n        mini_batch[\"attention_mask\"] = mini_batch[\"attention_mask\"].to(bool)\n        self.has_multi_modal_inputs = \"multi_modal_inputs\" in mini_batch.keys()\n        if self.has_multi_modal_inputs:\n            mini_batch.batch[\"multi_modal_inputs\"] = mini_batch.non_tensor_batch[\"multi_modal_inputs\"]\n            mini_batch.batch[\"multi_modal_inputs_idx\"] = torch.Tensor(\n                list(range(len(mini_batch.non_tensor_batch[\"multi_modal_inputs\"])))\n            ).to(torch.int64)\n\n        if mini_batch[\"position_ids\"].dim() == 3:  # qwen2vl mrope [bs, 3, seq_len]\n            mini_batch[\"position_ids\"] = mini_batch[\"position_ids\"][\n                :, 0\n            ]  # mcore patch recompute qwen2vl's pos ids during forward\n\n        indices = None\n        if use_dynamic_bsz:\n            assert max_token_len is not None, \"max_token_len must be set when use_dynamic_bsz is True\"\n            
vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()\n            if vpp_size is not None and vpp_size > 1:\n                microbatch_group_size_per_vp_stage = self.tf_config.microbatch_group_size_per_vp_stage\n                micro_batches, indices = rearrange_micro_batches(\n                    batch=mini_batch,\n                    num_batches_divided_by=microbatch_group_size_per_vp_stage,\n                    max_token_len=max_token_len,\n                )\n                assert len(micro_batches) % self.tf_config.microbatch_group_size_per_vp_stage == 0, (\n                    f\"micro_batches {micro_batches} must be divisible by microbatch_group_size_per_vp_stage \"\n                    f\"{microbatch_group_size_per_vp_stage} for megatron backend\"\n                )\n            else:\n                micro_batches, indices = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)\n        else:\n            assert micro_batch_size is not None, (\n                \"micro_batch_size is needed to be passed in when not using dynamic batch size\"\n            )\n            micro_batches = mini_batch.split(micro_batch_size)\n        # compute input shapes for pp stages\n        n_micro_batch = len(micro_batches)\n\n        forward_backward_func = get_forward_backward_func()\n\n        def loss_func(output, data):\n            # For memory efficiency\n            # We move calculation of entropy to compute_log_probs, forward_only == True\n            device = output[\"log_probs\"].device\n\n            responses = data[\"responses\"]\n            response_length = responses.size(1)\n\n            log_prob = output[\"log_probs\"][:, -response_length - 1 : -1].contiguous()\n            model_output = {\"log_probs\": log_prob}\n            if calculate_entropy:\n                entropy = output[\"entropy\"][:, -response_length - 1 : -1].contiguous()\n                model_output[\"entropy\"] = entropy\n\n            if forward_only:\n                # for inference\n                return torch.tensor(1.0, device=device), model_output\n\n            # for training\n            # note that this loss function can be swapped with other loss functions such as SFT\n            policy_loss, metrics = self.compute_ppo_loss(model_output, data)\n\n            # return loss and stats\n            return policy_loss, metrics\n\n        def forward_step(batch_iter, model):\n            batch = next(batch_iter)\n            batch = batch.to(get_device_id())\n            batch = batch.contiguous()\n\n            input_ids = batch[\"input_ids\"]\n            attention_mask = batch[\"attention_mask\"].to(bool)\n            position_ids = batch[\"position_ids\"]\n\n            multi_modal_inputs = {}\n            if \"multi_modal_inputs\" in batch:\n                for key in batch[\"multi_modal_inputs\"][0].keys():\n                    idxs = batch[\"multi_modal_inputs_idx\"]\n                    mmi = batch[\"multi_modal_inputs\"]\n                    multi_modal_inputs[key] = torch.cat(\n                        [mmi[idx].get(key) for idx in idxs if mmi[idx].get(key) is not None], dim=0\n                    )\n            responses = batch[\"responses\"]\n            response_length = responses.size(1)\n            label = position_ids.clone()\n            label[:, -response_length - 1 : -1] = responses\n            label_mask = attention_mask.clone()\n            label_mask[:, : -response_length - 1] = False\n            label_mask[:, -1] = False\n\n            from 
siirl.models.mcore import get_mcore_forward_fn, get_mcore_forward_fused_fn\n\n            if self.use_fused_kernels:\n                forward_fn = get_mcore_forward_fused_fn(self.hf_config)\n                # return dict of [logits, entropy]\n                output = forward_fn(\n                    model,\n                    input_ids,\n                    position_ids,\n                    attention_mask,\n                    sequence_parallel=self.tf_config.sequence_parallel,\n                    multi_modal_inputs=multi_modal_inputs,\n                    labels=label,\n                    labels_mask=label_mask,\n                    temperature=temperature,\n                )\n            else:\n                forward_fn = get_mcore_forward_fn(self.hf_config)\n\n                def logits_processor(logits, label, label_mask):\n                    assert logits.shape[:2] == label.shape[:2]\n                    assert label.shape == label_mask.shape\n                    logits.div_(temperature)\n                    ret = {}\n                    if calculate_entropy:\n                        logits_bak = logits.clone()\n                        entropy = vocab_parallel_entropy(logits)\n                        ret[\"entropy\"] = entropy\n                    else:\n                        logits_bak = logits\n                    log_probs = vocab_parallel_log_probs_from_logits(logits_bak, label)\n                    log_probs = log_probs.masked_fill(~label_mask, 0.0)\n                    ret[\"log_probs\"] = log_probs\n                    return ret\n\n                logits_processor_args = {\"label\": label, \"label_mask\": label_mask}\n                output = forward_fn(\n                    model,\n                    input_ids,\n                    attention_mask,\n                    position_ids,\n                    sequence_parallel=self.tf_config.sequence_parallel,\n                    multi_modal_inputs=multi_modal_inputs,\n                    logits_processor=logits_processor,\n                    logits_processor_args=logits_processor_args,\n                )\n\n            return output, partial(loss_func, data=batch)\n\n        # batch should be a list of batches inside micro-batches\n        batch_generator = make_batch_generator(micro_batches, vpp_size=len(self.actor_module))\n\n        # TODO: we may use the new schedule instead\n        # for flash-attn: (seq_len, batch_size, hidden_size) = (mbs*seq_len, 1, hidden_size)\n        losses_reduced = forward_backward_func(\n            forward_step_func=forward_step,\n            data_iterator=batch_generator,\n            model=self.actor_module,\n            num_microbatches=n_micro_batch,\n            seq_length=1,  # the communication shape is obtained via p2p comm\n            micro_batch_size=1,  # the communication shape is obtained via p2p comm\n            forward_only=forward_only,\n        )\n        # loss_reduces contains the stats returned from loss_func\n\n        if self.has_multi_modal_inputs:\n            data.batch.pop(\"multi_modal_inputs\")\n            data.batch.pop(\"multi_modal_inputs_idx\")\n            data.non_tensor_batch.pop(\"multi_modal_inputs\")\n\n        losses_reduced = {\"output\": losses_reduced}\n        if use_dynamic_bsz:\n            losses_reduced[\"indices\"] = indices\n        return losses_reduced\n\n    @GPUMemoryLogger(role=\"megatron actor\", logger=logger)\n    def update_policy(self, data:TensorDict) -> dict:\n        \"\"\"Update the policy with an iterator of 
TensorDict mini-batches.\n\n        Args:\n            data (TensorDict): The training batch. A temperature entry must be present to avoid a silent error.\n        Returns:\n            Dict: a dictionary containing the statistics. Note that the statistics are only valid on the last pp stage,\n            and users have to combine the outputs from each dp rank manually.\n\n        \"\"\"\n        metrics = {}\n        temperature = data[\"temperature\"]  # temperature must be present in the data to avoid a silent error\n        select_keys = [\n            \"responses\",\n            \"response_mask\",\n            \"input_ids\",\n            \"attention_mask\",\n            \"position_ids\",\n            \"old_log_probs\",\n            \"advantages\",\n        ]\n        if self.config.use_kl_loss:\n            select_keys.append(\"ref_log_prob\")\n        batch = data.select(*select_keys)\n        has_multi_modal_inputs = \"multi_modal_inputs\" in data.keys()\n\n        # Split to make minibatch iterator for updating the actor\n        # See PPO paper for details. https://arxiv.org/abs/1707.06347\n        if has_multi_modal_inputs:\n            num_mini_batches = data.batch_size[0] // self.config.ppo_mini_batch_size\n            mini_batches = batch.split(self.config.ppo_mini_batch_size)\n            multi_modal_inputs = np.array_split(data[\"multi_modal_inputs\"], num_mini_batches, axis=0)\n            for i in range(num_mini_batches):\n                mini_batches[i][\"multi_modal_inputs\"] = multi_modal_inputs[i]\n            dataloader = mini_batches\n        else:\n            dataloader = batch.split(self.config.ppo_mini_batch_size)\n\n        if self.use_torch_profiler and self.prof and self.prof.enable:\n            self.prof.start()\n        for data in dataloader:\n            self.actor_optimizer.zero_grad()\n            # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm\n            for chunk in self.actor_module:\n                # if using the distributed optimizer, zeroing the grad buffer is handled by the optimizer\n                chunk.zero_grad_buffer()\n\n            calculate_entropy = self.config.entropy_coeff != 0\n            if data.get(\"micro_batch_size\", None) is not None:\n                micro_batch_size = data[\"micro_batch_size\"]\n            else:\n                micro_batch_size = self.config.ppo_micro_batch_size_per_gpu\n            max_token_len = None\n            if self.config.use_dynamic_bsz:\n                max_token_len = self.config.ppo_max_token_len_per_gpu * self.config.megatron.context_parallel_size\n            metric_micro_batch = self.forward_backward_batch(\n                data,\n                temperature=temperature,\n                calculate_entropy=calculate_entropy,\n                use_dynamic_bsz=self.config.use_dynamic_bsz,\n                micro_batch_size=micro_batch_size,\n                max_token_len=max_token_len,\n            )\n            metric_micro_batch = metric_micro_batch[\"output\"]\n            for metric in metric_micro_batch:\n                # Note that o[0] is metrics, o[1] is entropy, o[2] is response_mask\n                append_to_dict(metrics, metric)  # append the metric from this micro-batch to global metrics.\n\n            update_successful, grad_norm, num_zeros_in_grad = self.actor_optimizer.step()\n            data = {\"actor/grad_norm\": grad_norm}\n            append_to_dict(metrics, data)\n\n            if update_successful:\n                # the param all-gather is already executed inside optimizer.step in recent Megatron\n                pass\n            else:\n
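                # optimizer.step() reported an unsuccessful update; no recovery path is implemented here\n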
                raise NotImplementedError\n            if self.use_torch_profiler and self.prof and self.prof.enable:\n                self.prof.step()\n        # empty the device cache once the policy update loop has finished\n        if self.use_torch_profiler and self.prof and self.prof.enable:\n            self.prof.stop_and_save()\n            self.prof.stop_trace()\n        get_torch_device().empty_cache()\n        return metrics\n"
  },
  {
    "path": "siirl/engine/base_worker/__init__.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .base.worker import Worker\nfrom .resouce_pool import RayResourcePool, RayClassWithInitArgs, WorkerGroup, get_random_string, sort_placement_group_by_node_ip\n\n__all__ = [\"Worker\", \"RayClassWithInitArgs\", \"RayResourcePool\", \"WorkerGroup\", \"get_random_string\", \"sort_placement_group_by_node_ip\"]\n"
  },
  {
    "path": "siirl/engine/base_worker/base/__init__.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "siirl/engine/base_worker/base/worker.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nthe class for Worker\n\"\"\"\n\nimport os\nimport socket\nfrom dataclasses import dataclass\nfrom typing import Dict\n\nimport ray\n\nfrom siirl.utils.extras.device import get_torch_device\n\n\n@dataclass\nclass DistRankInfo:\n    tp_rank: int\n    dp_rank: int\n    pp_rank: int\n    cp_rank: int\n\n\n@dataclass\nclass DistGlobalInfo:\n    tp_size: int\n    dp_size: int\n    pp_size: int\n    cp_size: int\n\n\nclass WorkerHelper:\n    def _get_node_ip(self):\n        def get_node_ip_by_sdk():\n            if os.getenv(\"WG_BACKEND\", None) == \"ray\":\n                import ray\n\n                return ray._private.services.get_node_ip_address()\n            else:\n                raise NotImplementedError(\"WG_BACKEND now just support ray mode.\")\n\n        host_ipv4 = os.getenv(\"MY_HOST_IP\", None)\n        host_ipv6 = os.getenv(\"MY_HOST_IPV6\", None)\n        host_ip_by_env = host_ipv4 or host_ipv6\n        host_ip_by_sdk = get_node_ip_by_sdk()\n\n        host_ip = host_ip_by_env or host_ip_by_sdk\n        return host_ip\n\n    def _get_free_port(self):\n        with socket.socket() as sock:\n            sock.bind((\"\", 0))\n            return sock.getsockname()[1]\n\n    def get_availale_master_addr_port(self):\n        return self._get_node_ip(), str(self._get_free_port())\n\n    def _get_pid(self):\n        return os.getpid()\n\n\n# we assume that in each WorkerGroup, there is a Master Worker\nclass Worker(WorkerHelper):\n    \"\"\"A distributed worker that handles initialization and configuration for distributed training.\n\n    This class manages worker initialization, configuration, and provides methods for executing\n    distributed operations. 
It handles communication settings, device configuration, and worker\n    metadata management.\n    \"\"\"\n\n    fused_worker_attr_name = \"fused_worker_dict\"\n\n    def __new__(cls, *args, **kwargs):\n        \"\"\"Create a new Worker instance with proper initialization based on environment settings.\"\"\"\n        instance = super().__new__(cls)\n\n        # note that here we use int to distinguish\n        disable_worker_init = int(os.environ.get(\"DISABLE_WORKER_INIT\", 0))\n        if disable_worker_init:\n            return instance\n\n        rank = os.environ.get(\"RANK\", None)\n        worker_group_prefix = os.environ.get(\"WG_PREFIX\", None)\n\n        # when decorator @ray.remote applies, __new__ will be called while we don't want to apply _configure_before_init\n        if None not in [rank, worker_group_prefix] and \"ActorClass(\" not in cls.__name__:\n            instance._configure_before_init(f\"{worker_group_prefix}_register_center\", int(rank))\n\n        return instance\n\n    def _configure_before_init(self, register_center_name: str, rank: int):\n        \"\"\"Configure worker settings before initialization.\n\n        Args:\n            register_center_name (str):\n                Name of the register center Ray actor for worker coordination\n            rank (int):\n                Rank of the worker in the distributed setup\n        \"\"\"\n        assert isinstance(rank, int), f\"rank must be int, instead of {type(rank)}\"\n\n        if rank == 0:\n            master_addr, master_port = self.get_availale_master_addr_port()\n            rank_zero_info = {\n                \"MASTER_ADDR\": master_addr,\n                \"MASTER_PORT\": master_port,\n            }\n\n            if os.getenv(\"WG_BACKEND\", None) == \"ray\":\n                from siirl.engine.base_worker.register_center.register_center import create_worker_group_register_center\n\n                self.register_center = create_worker_group_register_center(name=register_center_name, info=rank_zero_info)\n\n            os.environ.update(rank_zero_info)\n        else:\n            self.register_center = ray.get_actor(register_center_name)\n\n        # set worker info for node affinity scheduling\n        ray.get(self.register_center.set_worker_info.remote(rank, ray.get_runtime_context().get_node_id()))\n\n    @classmethod\n    def env_keys(cls):\n        \"\"\"The keys of the environment variables that are used to configure the Worker.\"\"\"\n        return [\"WORLD_SIZE\", \"RANK\", \"LOCAL_WORLD_SIZE\", \"LOCAL_RANK\", \"MASTER_ADDR\", \"MASTER_PORT\", \"CUDA_VISIBLE_DEVICES\"]\n\n    def __init__(self, cuda_visible_devices=None) -> None:\n        \"\"\"Initialize the worker with environment settings and device configuration.\n\n        Args:\n            cuda_visible_devices (str, optional):\n                CUDA visible devices configuration. Defaults to None.\n        \"\"\"\n        # construct a meta from environment variable. 
Note that the import must be inside the class because it is executed remotely\n        import os\n\n        self._setup_env_cuda_visible_devices()\n\n        world_size = int(os.environ[\"WORLD_SIZE\"])\n        rank = int(os.environ[\"RANK\"])\n        self._rank = rank\n        self._world_size = world_size\n\n        master_addr = os.environ[\"MASTER_ADDR\"]\n        master_port = os.environ[\"MASTER_PORT\"]\n\n        local_world_size = int(os.getenv(\"LOCAL_WORLD_SIZE\", \"1\"))\n        local_rank = int(os.getenv(\"LOCAL_RANK\", \"0\"))\n\n        store = {\n            \"_world_size\": world_size,\n            \"_rank\": rank,\n            \"_local_world_size\": local_world_size,\n            \"_local_rank\": local_rank,\n            \"_master_addr\": master_addr,\n            \"_master_port\": master_port,\n        }\n        if cuda_visible_devices is not None:\n            store[\"_cuda_visible_devices\"] = cuda_visible_devices\n\n        self._configure_with_store(store=store)\n\n        self.fused_worker_dict = {}\n\n    def get_fused_worker_by_name(self, worker_name: str):\n        \"\"\"Get a fused worker by its name.\n\n        Args:\n            worker_name (str):\n                Name of the worker to retrieve\n        \"\"\"\n        return self.fused_worker_dict.get(worker_name, None)\n\n    def _setup_env_cuda_visible_devices(self):\n        from siirl.utils.extras.ray_utils import ray_noset_visible_devices\n\n        is_ray_noset_visible_devices = ray_noset_visible_devices()\n\n        # Prevent use of clashing `{CUDA/HIP/ROCR}_VISIBLE_DEVICES``\n        rocr_val = os.environ.get(\"ROCR_VISIBLE_DEVICES\", None)\n        hip_val = os.environ.get(\"HIP_VISIBLE_DEVICES\", None)\n        cuda_val = os.environ.get(\"CUDA_VISIBLE_DEVICES\", None)\n        if hip_val:\n            # Switch the use of HIP_VISIBLE_DEVICES to CUDA_VISIBLE_DEVICES for consistency.\n            # Make sure that the HIP_VISIBLE_DEVICES is set to the same value as CUDA_VISIBLE_DEVICES\n            # at this point.\n            val = os.environ.pop(\"HIP_VISIBLE_DEVICES\")\n            hip_val = None\n            if cuda_val:\n                assert val == cuda_val, f\"Please use the same HIP_VISIBLE_DEVICES or CUDA_VISIBLE_DEVICES, inconsistant values found: {val} and {cuda_val}.\"\n            else:\n                cuda_val = val\n                os.environ[\"CUDA_VISIBLE_DEVICES\"] = val\n\n        if rocr_val:\n            # You must take care if both HIP/CUDA and ROCR env vars are set as they have\n            # different meanings. Both env vars accept either a list of ints or a\n            # list of UUIDs. 
The ROCR env var is processed first which then reduces\n            # the number of GPUs that HIP can select from.\n            # https://github.com/pytorch/pytorch/pull/144026\n            # To avoid the complexity of this, we simply gives out error if both are set\n            # (Also to keep consistency with ray's practice with 2.45.0).\n            # Otherwise, we will set ROCR_VISIBLE_DEVICES to CUDA_VISIBLE_DEVICES\n            # and remove ROCR_VISIBLE_DEVICES.\n            if cuda_val:\n                raise ValueError(\"Please don't set ROCR_VISIBLE_DEVICES when HIP/CUDA_VISIBLE_DEVICES is set.\")\n\n            cuda_val = os.environ.pop(\"ROCR_VISIBLE_DEVICES\")\n            os.environ[\"CUDA_VISIBLE_DEVICES\"] = cuda_val\n            rocr_val = None\n\n        if is_ray_noset_visible_devices:\n            # NOTE: Ray will automatically set the *_VISIBLE_DEVICES\n            # environment variable for each actor, unless\n            # RAY_EXPERIMENTAL_NOSET_*_VISIBLE_DEVICES is set,\n            # so we need to set local rank when the flag is set.\n            local_rank = os.environ.get(\"RAY_LOCAL_RANK\")\n            os.environ[\"LOCAL_RANK\"] = local_rank\n            get_torch_device().set_device(int(local_rank))\n\n    def _configure_with_store(self, store: Dict):\n        \"\"\"\n        This function should only be called inside by WorkerGroup\n        \"\"\"\n        store_env_dict = {f\"_{key.lower()}\": store.get(f\"_{key.lower()}\", None) for key in type(self).env_keys()}\n        self.__dict__.update(store_env_dict)  # this is hacky\n        # print(f\"__dict__: {self.__dict__}\")\n        for key in type(self).env_keys():\n            val = self.__dict__.get(f\"_{key.lower()}\", None)\n            if val is not None:\n                # print(f\"set {key} to {val}\")\n                os.environ[key] = str(val)\n        os.environ[\"REDIS_STORE_SERVER_HOST\"] = str(self._master_addr).replace(\"[\", \"\").replace(\"]\", \"\") if self._master_addr else \"\"\n\n    def get_master_addr_port(self):\n        \"\"\"Get the master address and port for distributed communication.\"\"\"\n        return self._master_addr, self._master_port\n\n    def get_cuda_visible_devices(self):\n        \"\"\"Get the CUDA visible devices configuration.\"\"\"\n        import os\n\n        cuda_visible_devices = os.environ.get(\"CUDA_VISIBLE_DEVICES\", \"not set\")\n        return cuda_visible_devices\n\n    @property\n    def world_size(self):\n        \"\"\"Get the total number of workers in the distributed setup.\"\"\"\n        return self._world_size\n\n    @property\n    def rank(self):\n        \"\"\"Get the rank of this worker in the distributed setup.\"\"\"\n        return self._rank\n\n    def execute_with_func_generator(self, func, *args, **kwargs):\n        \"\"\"Execute a function with function generator dispatch mode.\n\n        Args:\n            func:\n                Function to execute\n            *args:\n                Positional arguments for the function\n            **kwargs:\n                Keyword arguments for the function\n        \"\"\"\n        ret_proto = func(self, *args, **kwargs)\n        return ret_proto\n\n    def execute_func_rank_zero(self, func, *args, **kwargs):\n        \"\"\"Execute a function in rank zero execution mode.\n\n        Args:\n            func:\n                Function to execute\n            *args:\n                Positional arguments for the function\n            **kwargs:\n                Keyword arguments for the function\n        
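Returns:\n            The result of calling func(*args, **kwargs) directly on this worker process.\n        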
\"\"\"\n        result = func(*args, **kwargs)\n        return result\n"
  },
  {
    "path": "siirl/engine/base_worker/megatron/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "siirl/engine/base_worker/megatron/npu_mbridge_patch.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom mbridge.core.util import unwrap_model\nfrom mbridge.core import Bridge\nimport torch\n\n\ndef load_weights_patch(\n    self,\n    models: list[torch.nn.Module],\n    weights_path: str,\n    memory_efficient: bool = False,\n) -> None:\n    \"\"\"\n    Load weights from a Hugging Face model into a Megatron-Core model.\n\n    Args:\n        models: List of model instances, supporting VPP (Virtual Pipeline Parallelism)\n        weights_path: Path to the weights file or Hugging Face model identifier\n    \"\"\"\n    self.safetensor_io = self._get_safetensor_io(weights_path)\n\n    for i, model in enumerate(models):\n        # map local weight names to global weight names\n        local_to_global_map = self._weight_name_mapping_mcore_local_to_global(model)\n        # map local weight names to huggingface weight names\n        local_to_hf_map = {\n            k: self._weight_name_mapping_mcore_to_hf(v)\n            for k, v in local_to_global_map.items()\n            if \"_extra_state\" not in k\n        }\n        # only tp_rank0/etp_rank0 load from disk, others load from tp_rank0/etp_rank0\n        to_load_from_disk = []\n        for local_name, hf_names in local_to_hf_map.items():\n            if \".mlp.experts.linear_fc\" in local_name:\n                if self.mpu.etp_rank == 0:\n                    to_load_from_disk.extend(hf_names)\n            else:\n                if self.mpu.tp_rank == 0:\n                    to_load_from_disk.extend(hf_names)\n                else:\n                    # special case for lm_head.weight\n                    # if make value model, every tp rank will load lm_head.weight\n                    if \"lm_head.weight\" in hf_names:\n                        to_load_from_disk.extend(hf_names)\n\n        # load huggingface weights\n        if not memory_efficient:\n            hf_weights_map = self.safetensor_io.load_some_hf_weight(\n                to_load_from_disk\n            )\n        model = unwrap_model(model)\n        # Some weights are in named_parameters but not in state_dict.\n        with torch.no_grad():\n            for local_name, hf_names in local_to_hf_map.items():\n                # Maybe a bug in torch_npu. 
Some weights are registered in named_parameters but not in state_dict.\n                if model.state_dict().get(local_name, None) is None:\n                    param = dict(model.named_parameters())[local_name]\n                else: \n                    param = model.state_dict()[local_name]\n                # hf format to mcore format\n                if set(to_load_from_disk) & set(hf_names):\n                    if not memory_efficient:\n                        hf_weights = [hf_weights_map[x] for x in hf_names]\n                    else:\n                        hf_weights = [\n                            self.safetensor_io.load_one_hf_weight(x) for x in hf_names\n                        ]\n                    mcore_weight = self._weight_to_mcore_format(local_name, hf_weights)\n                else:\n                    mcore_weight = None\n                if hf_names[0] == \"lm_head.weight\":\n                    if param.shape[0] == 1 and mcore_weight.shape[0] != 1:\n                        # skip lm_head.weight when the model is a value model\n                        continue\n\n                param_to_load = torch.empty_like(param)\n                if \".mlp.experts.linear_fc\" in local_name:\n                    # split mcore weights across etp\n                    if self.mpu.etp_rank == 0:\n                        mcore_weights_tp_split = self._weight_split_across_tp(\n                            local_name, mcore_weight, param, self.mpu.etp_size\n                        )\n                        mcore_weights_tp_split = list(mcore_weights_tp_split)\n                        mcore_weights_tp_split = [\n                            t.to(param.device, dtype=param.dtype).contiguous()\n                            for t in mcore_weights_tp_split\n                        ]\n                    else:\n                        mcore_weights_tp_split = None\n                    torch.distributed.scatter(\n                        param_to_load,\n                        mcore_weights_tp_split,\n                        src=torch.distributed.get_global_rank(self.mpu.etp_group, 0),\n                        group=self.mpu.etp_group,\n                    )\n                else:\n                    # split mcore weights across tp\n                    if self.mpu.tp_rank == 0:\n                        mcore_weights_tp_split = self._weight_split_across_tp(\n                            local_name, mcore_weight, param, self.mpu.tp_size\n                        )\n                        mcore_weights_tp_split = list(mcore_weights_tp_split)\n                        mcore_weights_tp_split = [\n                            t.to(param.device, dtype=param.dtype).contiguous()\n                            for t in mcore_weights_tp_split\n                        ]\n                    else:\n                        mcore_weights_tp_split = None\n                    torch.distributed.scatter(\n                        param_to_load,\n                        mcore_weights_tp_split,\n                        src=torch.distributed.get_global_rank(self.mpu.tp_group, 0),\n                        group=self.mpu.tp_group,\n                    )\n                # load\n                param.copy_(param_to_load.detach())\n\ndef _weight_name_mapping_mcore_local_to_global_patch(\n    self, model: torch.nn.Module, consider_ep: bool = True\n) -> dict[str, str]:\n    \"\"\"\n    Map local weight names to global weight names, supporting VPP and EP.\n\n    Args:\n        model: The model instance\n\n    Returns:\n        
dict: Mapping from local weight names to global weight names\n    \"\"\"\n    # vpp\n    local_layer_to_global_layer = {}\n    model = unwrap_model(model)\n    if hasattr(model, \"decoder\"):\n        for idx, layer in enumerate(model.decoder.layers):\n            local_layer_to_global_layer[idx] = layer.layer_number - 1\n    # Maybe a bug in torch_npu. Some weights are registered in named_parameters but not in state_dict.\n    all_named_param_names = [\n        k for k,_ in model.named_parameters() if \"_extra_state\" not in k\n    ]\n    all_state_dict_keys = [\n        k for k in model.state_dict().keys() if \"_extra_state\" in k\n    ]\n    all_param_names = list(dict.fromkeys(all_named_param_names + all_state_dict_keys))\n    ret = {}\n    for param_name in all_param_names:\n        keyword = \"decoder.layers.\"\n        if keyword in param_name:\n            layer_idx = int(param_name.split(keyword)[1].split(\".\")[0])\n            global_layer_idx = local_layer_to_global_layer[layer_idx]\n            ret[param_name] = param_name.replace(\n                f\"layers.{layer_idx}.\", f\"layers.{global_layer_idx}.\"\n            )\n        else:\n            ret[param_name] = param_name\n\n    # ep\n    if self.mpu.ep_size > 1 and consider_ep:\n        num_experts = self.config.num_moe_experts\n        num_experts_per_rank = num_experts // self.mpu.ep_size\n        local_expert_to_global_expert = {\n            i: i + num_experts_per_rank * self.mpu.ep_rank\n            for i in range(num_experts_per_rank)\n        }\n        for k in ret.keys():\n            v = ret[k]\n            if \".mlp.experts.linear_fc\" in v:\n                name_prefix, local_expert_id = v.split(\".weight\")\n                global_expert_idx = local_expert_to_global_expert[\n                    int(local_expert_id)\n                ]\n                ret[k] = f\"{name_prefix}.weight{global_expert_idx}\"\n\n    return ret\n\nBridge.load_weights = load_weights_patch\nBridge._weight_name_mapping_mcore_local_to_global = _weight_name_mapping_mcore_local_to_global_patch\n"
  },
  {
    "path": "siirl/engine/base_worker/megatron/worker.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2025, Infrawaves. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom siirl.engine.base_worker.base.worker import DistGlobalInfo, DistRankInfo, Worker\nfrom siirl.params import ActorRolloutRefArguments\nfrom siirl.utils.extras.device import is_npu_available\n\nclass MegatronWorker(Worker):\n    def __init__(self, cuda_visible_devices=None) -> None:\n        super().__init__(cuda_visible_devices)\n\n    def get_megatron_global_info(self):\n        from megatron.core import parallel_state as mpu\n\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        dp_size = mpu.get_data_parallel_world_size()\n        pp_size = mpu.get_pipeline_model_parallel_world_size()\n        cp_size = mpu.get_context_parallel_world_size()\n        info = DistGlobalInfo(tp_size=tp_size, dp_size=dp_size, pp_size=pp_size, cp_size=cp_size)\n        return info\n\n    def get_megatron_rank_info(self):\n        from megatron.core import parallel_state as mpu\n\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        dp_rank = mpu.get_data_parallel_rank()\n        pp_rank = mpu.get_pipeline_model_parallel_rank()\n        cp_rank = mpu.get_context_parallel_rank()\n        info = DistRankInfo(tp_rank=tp_rank, dp_rank=dp_rank, pp_rank=pp_rank, cp_rank=cp_rank)\n        return info\n\n    # TODO(Ping Zhang): Seperate this function for rollout\n    def _init_hf_config_and_tf_config(\n        self,\n        model_path,\n        tokenizer_or_path,\n        dtype,\n        override_model_config,\n        override_transformer_config,\n        trust_remote_code=False,\n        use_mbridge=False,\n    ):\n        from transformers import AutoConfig\n\n        from siirl.models.mcore import hf_to_mcore_config\n        from siirl.models.loader import load_tokenizer\n        from siirl.utils.extras.fs import copy_to_local\n        from siirl.utils.model_utils.model import update_model_config\n\n        # Step 1: initialize the tokenizer\n        self.local_path = copy_to_local(model_path)\n        if tokenizer_or_path is None:\n            tokenizer_processor = load_tokenizer(path=self.local_path)\n            self.tokenizer = tokenizer_processor[\"tokenizer\"]\n            self.processor = tokenizer_processor[\"processor\"]\n        elif isinstance(tokenizer_or_path, str):\n            tokenizer_processor = load_tokenizer(path=copy_to_local(tokenizer_or_path))\n            self.tokenizer = tokenizer_processor[\"tokenizer\"]\n            self.processor = tokenizer_processor[\"processor\"]\n        else:\n            self.tokenizer = tokenizer_or_path\n            self.processor = tokenizer_or_path\n\n        # Step 2: get the hf\n        hf_config = AutoConfig.from_pretrained(self.local_path, trust_remote_code=trust_remote_code)\n\n        # Step 3: override the hf config\n        override_config_kwargs = {\n            \"bos_token_id\": self.tokenizer.bos_token_id,\n            \"eos_token_id\": 
self.tokenizer.eos_token_id,\n            \"pad_token_id\": self.tokenizer.pad_token_id,\n        }\n        override_config_kwargs.update(override_model_config.get(\"model_config\", {}))\n        self.share_embeddings_and_output_weights = getattr(hf_config, \"tie_word_embeddings\", False)\n        update_model_config(hf_config, override_config_kwargs=override_config_kwargs)\n        self.architectures = getattr(hf_config, \"architectures\", None)\n        if self.rank == 0:\n            print(f\"Model config after override: {hf_config}\")\n        tf_config = hf_to_mcore_config(hf_config, dtype, **override_transformer_config)\n\n        def add_optimization_config_to_tf_config(tf_config):\n            # add optimization config to tf_config, e.g. checkpointing\n            if self.config.model.enable_gradient_checkpointing:\n                gradient_checkpointing_cfg = dict(self.config.model.gradient_checkpointing_kwargs)\n                tf_config.recompute_method = gradient_checkpointing_cfg.get(\"activations_checkpoint_method\", \"uniform\")\n                tf_config.recompute_granularity = gradient_checkpointing_cfg.get(\"activations_checkpoint_granularity\", None)\n                tf_config.recompute_num_layers = gradient_checkpointing_cfg.get(\"activations_checkpoint_num_layers\", 1)\n            \n            if isinstance(self.config, ActorRolloutRefArguments):\n                megatron_config = self.config.actor.megatron\n            else:\n                megatron_config = self.config.megatron\n\n            if megatron_config:\n                if extra := megatron_config.extra:\n                    for k, v in extra.items():\n                        setattr(tf_config, k, v)\n\n        add_optimization_config_to_tf_config(tf_config)\n\n        if use_mbridge:\n            if is_npu_available:\n                if self.rank == 0:\n                    print(f\"Patching mbridge for NPU ......\")\n                from . import npu_mbridge_patch\n            from siirl.models.mcore.mbridge import AutoBridge\n\n            bridge = AutoBridge.from_config(hf_config)\n            bridge.set_extra_args(**override_transformer_config)\n            tf_config = bridge.config\n            self.bridge = bridge\n        else:\n            self.bridge = None\n\n        print(f\"TF config: {tf_config}\")\n        self.hf_config = hf_config\n        self.tf_config = tf_config\n"
  },
  {
    "path": "siirl/engine/base_worker/register_center/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "siirl/engine/base_worker/register_center/register_center.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Dict\n\nimport ray\n\n\n@ray.remote\nclass WorkerGroupRegisterCenter:\n    def __init__(self, rank_zero_info):\n        self.rank_zero_info = rank_zero_info\n        # rank -> node_id\n        self.workers_info: Dict[int, str] = {}\n\n    def get_rank_zero_info(self):\n        return self.rank_zero_info\n\n    def set_worker_info(self, rank, node_id) -> None:\n        self.workers_info[rank] = node_id\n\n    def get_worker_info(self) -> Dict[int, str]:\n        return self.workers_info\n\n\ndef create_worker_group_register_center(name, info):\n    return WorkerGroupRegisterCenter.options(name=name, get_if_exists=True).remote(info)\n"
  },
  {
    "path": "siirl/engine/base_worker/resouce_pool.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport threading\nfrom typing import Any, Dict, List, Optional, Tuple\n\nfrom loguru import logger\nimport ray\nfrom ray.util.placement_group import PlacementGroup, placement_group\nfrom ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy, PlacementGroupSchedulingStrategy\n\n\nclass ResourcePool:\n    \"\"\"\n    Manages a pool of resources across multiple nodes, tracking process counts and GPU allocations.\n    The class provides methods to calculate world size, local world sizes, and local ranks\n    across all nodes in the pool.\n    \"\"\"\n\n    def __init__(self, process_on_nodes=None, max_colocate_count: int = 10, n_gpus_per_node=8) -> None:\n        \"\"\"Initialize the ResourcePool with node processes and GPU configuration.\n\n        Args:\n            process_on_nodes (List[int], optional): List of process counts per node. Defaults to empty list.\n            max_colocate_count (int, optional): Maximum number of processes that can be colocated. Defaults to 10.\n            n_gpus_per_node (int, optional): Number of GPUs available per node. Defaults to 8.\n        \"\"\"\n        if process_on_nodes is None:\n            process_on_nodes = []\n        self._store = process_on_nodes\n        self.max_colocate_count = max_colocate_count\n        self.n_gpus_per_node = n_gpus_per_node  # this is left for future huawei GPU that contains 16 GPUs per node\n\n    def add_node(self, process_count):\n        self._store.append(process_count)\n\n    @property\n    def world_size(self):\n        \"\"\"Total number of processes across all nodes in the pool.\"\"\"\n        return sum(self._store)\n\n    def __call__(self) -> Any:\n        return self._store\n\n    @property\n    def store(self):\n        return self._store\n\n    def local_world_size_list(self) -> List[int]:\n        \"\"\"Returns a flat list where each process has its local world size.\"\"\"\n        nested_local_world_size_list = [[local_world_size for _ in range(local_world_size)] for local_world_size in self._store]\n        return [item for row in nested_local_world_size_list for item in row]\n\n    def local_rank_list(self) -> List[int]:\n        \"\"\"Returns a flat list of local ranks for all processes across all nodes.\"\"\"\n        nested_local_rank_list = [[i for i in range(local_world_size)] for local_world_size in self._store]\n        return [item for row in nested_local_rank_list for item in row]\n\n\nclass ClassWithInitArgs:\n    \"\"\"\n    Wrapper class that stores constructor arguments for deferred instantiation.\n    This class is particularly useful for remote class instantiation where\n    the actual construction needs to happen at a different time or location.\n    \"\"\"\n\n    def __init__(self, cls, *args, **kwargs) -> None:\n        \"\"\"Initialize the ClassWithInitArgs instance.\n\n        Args:\n            cls: The class to be instantiated later\n         
   *args: Positional arguments for the class constructor\n            **kwargs: Keyword arguments for the class constructor\n        \"\"\"\n        self.cls = cls\n        self.args = args\n        self.kwargs = kwargs\n\n        self.fused_worker_used = False\n\n    def __call__(self) -> Any:\n        \"\"\"Instantiate the stored class with the stored arguments.\"\"\"\n        return self.cls(*self.args, **self.kwargs)\n\n\nclass WorkerGroup:\n    \"\"\"\n    Base class for managing a group of workers in a distributed system.\n    The class provides methods for worker management, aliveness checking, and method binding.\n    \"\"\"\n\n    fused_worker_execute_fn_name = \"_fuw_execute\"\n\n    def __init__(self, resource_pool: ResourcePool, **kwargs) -> None:\n        self._is_init_with_detached_workers = resource_pool is None\n\n        self.fused_worker_used = False\n\n        if resource_pool is not None:\n            # handle the case when WorkGroup is attached to an existing one\n            self._procecss_dispatch_config = resource_pool()\n        else:\n            self._procecss_dispatch_config = None\n\n        self._workers = []\n        self._worker_names = []\n\n        self._master_addr = None\n        self._master_port = None\n\n        self._checker_thread: threading.Thread = None\n\n    def _is_worker_alive(self, worker):\n        \"\"\"Check if a worker is alive. Must be implemented by derived classes.\"\"\"\n        raise NotImplementedError(\"WorkerGroup._is_worker_alive called, should be implemented in derived class.\")\n\n    @property\n    def world_size(self):\n        \"\"\"Number of workers in the group.\"\"\"\n        return len(self._workers)\n\n\ndef get_random_string(length: int) -> str:\n    import random\n    import string\n\n    letters_digits = string.ascii_letters + string.digits\n    return \"\".join(random.choice(letters_digits) for _ in range(length))\n\n\ndef sort_placement_group_by_node_ip(pgs: List[PlacementGroup]) -> List[PlacementGroup]:\n    \"\"\"\n    Sort the placement groups by node ip, all bundles in a single placement group should be on the same node.\n\n    FSDPCheckpointManager saves sharded model states and optimizer states in local storage, which requires RANK\n    to be consistent across nodes when resume from checkpoint.\n\n    With this function, if there's only one resource pool and there's no node change, RANK should be consistent\n    across nodes in multiple ray jobs, even if the whole ray cluster is restarted.\n    \"\"\"\n    node_ip = {node[\"NodeID\"]: node[\"NodeManagerAddress\"] for node in ray.nodes()}\n    pg_ip = {}\n    for pg in pgs:\n        specs = ray._private.state.state.placement_group_table(pg.id)\n        # all bunles should be on the same node\n        node_id = specs[\"bundles_to_node_id\"][0]\n        pg_ip[pg.id] = node_ip[node_id]\n    return sorted(pgs, key=lambda pg: pg_ip[pg.id])\n\n\nclass RayResourcePool(ResourcePool):\n    def __init__(\n        self,\n        process_on_nodes: Optional[List[int]] = None,\n        use_gpu: bool = True,\n        name_prefix: str = None,\n        max_colocate_count: int = 10,\n        detached=False,\n        accelerator_type: Optional[str] = None,\n    ) -> None:\n        super().__init__(process_on_nodes, max_colocate_count)\n        self.use_gpu = use_gpu\n        # print(f\"in RayProcessDispatchConfiguration: name_prefix = {name_prefix}\")\n        self.name_prefix = get_random_string(length=6) if name_prefix is None else name_prefix\n        self.pgs = None\n        
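# placement groups are created lazily by get_placement_groups() and cached here\n        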
self.detached = detached\n        self.accelerator_type = accelerator_type\n\n    def get_placement_groups(self, strategy=\"STRICT_PACK\", name=None, device_name=\"cuda\"):\n        if self.pgs is not None:\n            return self.pgs\n\n        pg_name_prefix = name if name else f\"{self.name_prefix}siirl_group_{'_'.join([str(count) for count in self._store])}:\"\n        # print(f\"pg_name_prefix = {pg_name_prefix}\")\n        if device_name == \"npu\":\n            device_name = \"NPU\"\n        elif device_name == \"cuda\":\n            device_name = \"GPU\"\n\n        bundle = {\"CPU\": self.max_colocate_count}\n        if self.use_gpu:\n            bundle[device_name] = 1\n            if self.accelerator_type is not None:\n                bundle[self.accelerator_type] = 1e-4\n        pg_scheme = [[bundle.copy() for _ in range(process_count)] for process_count in self._store]\n\n        lifetime = \"detached\" if self.detached else None\n\n        pgs = [placement_group(bundles=bundles, strategy=strategy, name=pg_name_prefix + str(idx), lifetime=lifetime) for idx, bundles in enumerate(pg_scheme)]\n\n        ray.get([pg.ready() for pg in pgs])\n\n        self.pgs = pgs\n        return pgs\n\n\ndef extract_pg_from_exist(resource_pools: Dict[str, RayResourcePool], src_role_names: List[str], resource_pool: RayResourcePool) -> List:\n    src_pgs = [pg for role_name, resource_pool in resource_pools.items() for pg in resource_pool.get_placement_groups() if role_name in src_role_names]\n\n    sorted_src_pgs = sorted(src_pgs, key=lambda pg: pg.bundle_count, reverse=True)\n    sorted_process_on_nodes = sorted([(val, idx) for idx, val in enumerate(resource_pool.store)], reverse=True)\n\n    unsorted_pgs: List[Tuple[int, PlacementGroup]] = []\n    searching_idx = 0\n    for request_process, original_idx in sorted_process_on_nodes:\n        assert searching_idx < len(sorted_src_pgs), f\"no enough nodes for request: searching {searching_idx} th node\"\n        assert request_process <= sorted_src_pgs[searching_idx].bundle_count, f\"requesting {request_process} processes, bundle count cannot satisfy\"\n        unsorted_pgs.append((original_idx, sorted_src_pgs[searching_idx]))\n        searching_idx += 1\n\n    return [pg for _, pg in sorted(unsorted_pgs)]\n\n\ndef merge_resource_pool(rp1: RayResourcePool, rp2: RayResourcePool) -> RayResourcePool:\n    assert rp1.use_gpu == rp2.use_gpu, \"Both RayResourcePool must either use_gpu or not\"\n    assert rp1.max_colocate_count == rp2.max_colocate_count, \"Both RayResourcePool must has the same max_colocate_count\"\n    assert rp1.n_gpus_per_node == rp2.n_gpus_per_node, \"Both RayResourcePool must has the same n_gpus_per_node\"\n    assert rp1.detached == rp2.detached, \"Detached ResourcePool cannot be merged with non-detached ResourcePool\"\n\n    new_store = rp1.store + rp2.store\n\n    merged = type(rp1)(new_store, rp1.use_gpu, f\"{rp1.name_prefix}_{rp2.name_prefix}\")\n    merged.pgs = rp1.get_placement_groups() + rp2.get_placement_groups()\n\n    return merged\n\n\nclass RayClassWithInitArgs(ClassWithInitArgs):\n    \"\"\"A wrapper class for Ray actors with initialization arguments.\n\n    This class extends ClassWithInitArgs to provide additional functionality for\n    configuring and creating Ray actors with specific resource requirements and\n    scheduling strategies.\n    \"\"\"\n\n    def __init__(self, cls, *args, **kwargs) -> None:\n        # self._options = kwargs.pop('options', dict())\n        super().__init__(cls, *args, **kwargs)\n       
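 # options for ray actor creation; merged into the scheduling options in __call__\n       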
 self._options = {}\n        self._additional_resource = {}\n\n    def set_additional_resource(self, additional_resource):\n        \"\"\"Set additional resource requirements for the actor.\n\n        Args:\n            additional_resource: Dictionary specifying additional resource requirements\n        \"\"\"\n        self._additional_resource = additional_resource\n\n    def update_options(self, options: Dict):\n        \"\"\"Update the Ray actor creation options.\n\n        Args:\n            options: Dictionary of options to update\n        \"\"\"\n        self._options.update(options)\n\n    def __call__(self, placement_group, placement_group_bundle_idx, use_gpu: bool = True, num_gpus=1, sharing_with=None, rank=0, device_name=\"cuda\") -> Any:\n        \"\"\"Create and return a Ray actor with the configured options.\n\n        Args:\n            placement_group: Ray placement group for scheduling\n            placement_group_bundle_idx: Index of the bundle in the placement group\n            use_gpu: Whether to use GPU resources\n            num_gpus: Number of GPUs to allocate\n            sharing_with: Actor to share resources with\n\n        Returns:\n            A Ray actor handle with the configured options\n        \"\"\"\n        # Do not mutate self.kwargs, as this object is shared across all ranks.\n        local_kwargs = self.kwargs.copy()\n        local_kwargs.pop(\"device_name\", \"cuda\")\n        \n        if sharing_with is not None:\n            target_node_id = ray.get(sharing_with.get_node_id.remote())\n            cuda_visible_devices = ray.get(sharing_with.get_cuda_visible_devices.remote())\n            options = {\"scheduling_strategy\": NodeAffinitySchedulingStrategy(node_id=target_node_id, soft=False)}\n            return self.cls.options(**options).remote(*self.args, cuda_visible_devices=cuda_visible_devices, **local_kwargs)\n\n        options = {\"scheduling_strategy\": PlacementGroupSchedulingStrategy(placement_group=placement_group, placement_group_bundle_index=placement_group_bundle_idx)}\n        options.update(self._options)\n\n        if use_gpu and device_name == \"cuda\":\n            options[\"num_gpus\"] = num_gpus\n        if use_gpu and device_name == \"npu\":\n            options[\"resources\"] = {\"NPU\": num_gpus}\n\n        if len(self._additional_resource) > 1:\n            for k, v in self._additional_resource.items():\n                options[k] = v\n\n        # print(\"cls:\", self.cls)\n        # print(\"args: \", self.args)\n        # print(\"kwargs: \", self.kwargs)\n        return self.cls.options(**options).remote(*self.args, **local_kwargs)"
  },
  {
    "path": "siirl/engine/critic/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .base import BasePPOCritic\nfrom .dp_critic import DataParallelPPOCritic\n\n__all__ = [\"BasePPOCritic\", \"DataParallelPPOCritic\"]\n"
  },
  {
    "path": "siirl/engine/critic/base.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nBase class for a critic\n\"\"\"\n\nfrom abc import ABC, abstractmethod\n\nimport torch\n\nfrom tensordict import TensorDict\n\n__all__ = [\"BasePPOCritic\"]\n\n\nclass BasePPOCritic(ABC):\n    def __init__(self, config):\n        super().__init__()\n        self.config = config\n\n    @abstractmethod\n    def compute_values(self, data: TensorDict) -> torch.Tensor:\n        \"\"\"Compute values\"\"\"\n        pass\n\n    @abstractmethod\n    def update_critic(self, data: TensorDict):\n        \"\"\"Update the critic\"\"\"\n        pass\n"
  },
  {
    "path": "siirl/engine/critic/dp_critic.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nImplement a multiprocess PPOCritic\n\"\"\"\n\nimport itertools\n\nimport torch\nimport torch.distributed\nfrom torch import nn, optim\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\nfrom tensordict import TensorDict\nfrom siirl.dag_worker import core_algos\nfrom siirl.utils.extras.device import get_device_id, get_device_name, is_cuda_available, is_npu_available\nfrom siirl.utils.model_utils.fsdp_utils import FSDPModule, fsdp2_clip_grad_norm_\nfrom siirl.utils.extras.py_functional import append_to_dict\nfrom siirl.utils.model_utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches\nfrom siirl.utils.model_utils.torch_functional import masked_mean\nfrom siirl.utils.model_utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs\nfrom siirl.engine.critic import BasePPOCritic\nfrom siirl.params.model_args import CriticArguments\n\nif is_cuda_available:\n    from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input\nelif is_npu_available:\n    from transformers.integrations.npu_flash_attention import index_first_axis, pad_input, rearrange, unpad_input\n\n\nclass DataParallelPPOCritic(BasePPOCritic):\n    def __init__(self, config: CriticArguments, critic_module: nn.Module, critic_optimizer: optim.Optimizer):\n        super().__init__(config=config)\n        self.critic_module = critic_module\n        self.critic_optimizer = critic_optimizer\n        self.use_remove_padding = self.config.model.use_remove_padding\n        print(f\"Critic use_remove_padding={self.use_remove_padding}\")\n\n        self.ulysses_sequence_parallel_size = self.config.ulysses_sequence_parallel_size\n        self.device_name = get_device_name()\n\n    def _forward_micro_batch(self, micro_batch):\n        response_length = micro_batch[\"responses\"].size(-1)\n        multi_modal_inputs = {}\n        if \"multi_modal_inputs\" in micro_batch.keys():\n            for key in micro_batch[\"multi_modal_inputs\"][0].keys():\n                multi_modal_inputs[key] = torch.cat([inputs[key] for inputs in micro_batch[\"multi_modal_inputs\"]], dim=0)\n\n        with torch.autocast(device_type=self.device_name, dtype=torch.bfloat16):\n            input_ids = micro_batch[\"input_ids\"]\n            batch, seqlen = input_ids.shape\n            attention_mask = micro_batch[\"attention_mask\"]\n            position_ids = micro_batch[\"position_ids\"]\n            if position_ids.dim() == 3:  # qwen2vl mrope\n                position_ids = position_ids.transpose(0, 1)\n\n            if self.use_remove_padding:\n                input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask)  # input_ids_rmpad (total_nnz, ...)\n                input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)\n\n                # unpad the position_ids to align the rotary\n                if position_ids.dim() 
== 3:\n                    position_ids_rmpad = index_first_axis(rearrange(position_ids, \"c b s ... -> (b s) c ...\"), indices).transpose(0, 1).unsqueeze(1)  # (3, bsz, seqlen) -> (3, 1, bsz * seqlen)\n                else:\n                    position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), \"b s ... -> (b s) ...\"), indices).transpose(0, 1)\n\n                # pad and slice the inputs if sp > 1\n                if self.ulysses_sequence_parallel_size > 1:\n                    input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size)\n\n                # only pass input_ids and position_ids to enable flash_attn_varlen\n                output = self.critic_module(\n                    input_ids=input_ids_rmpad,\n                    attention_mask=None,\n                    position_ids=position_ids_rmpad,\n                    **multi_modal_inputs,\n                    use_cache=False,\n                )  # prevent model thinks we are generating\n                values_rmpad = output.logits\n                values_rmpad = values_rmpad.squeeze(0)  # (total_nnz)\n\n                # gather output if sp > 1\n                if self.ulysses_sequence_parallel_size > 1:\n                    values_rmpad = gather_outpus_and_unpad(values_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size)\n\n                # pad it back\n                values = pad_input(values_rmpad, indices=indices, batch=batch, seqlen=seqlen).squeeze(-1)\n                values = values[:, -response_length - 1 : -1]\n            else:\n                output = self.critic_module(\n                    input_ids=input_ids,\n                    attention_mask=attention_mask,\n                    position_ids=position_ids,\n                    **multi_modal_inputs,\n                    use_cache=False,\n                )  # prevent model thinks we are generating\n                values = output.logits\n                values = values[:, -response_length - 1 : -1].squeeze(-1)\n            return values\n\n    def _optimizer_step(self):\n        assert self.config.grad_clip is not None\n\n        if isinstance(self.critic_module, FSDP):\n            grad_norm = self.critic_module.clip_grad_norm_(self.config.grad_clip)\n        elif isinstance(self.critic_module, FSDPModule):\n            grad_norm = fsdp2_clip_grad_norm_(self.critic_module.parameters(), max_norm=self.config.grad_clip)\n        else:\n            grad_norm = torch.nn.utils.clip_grad_norm_(self.critic_module.parameters(), max_norm=self.config.grad_clip)\n\n        # if grad_norm is not finite, skip the update\n        if not torch.isfinite(grad_norm):\n            print(f\"WARN: grad_norm is not finite: {grad_norm}\")\n            self.critic_optimizer.zero_grad()\n        else:\n            self.critic_optimizer.step()\n        return grad_norm\n\n    # @GPUMemoryLogger(role=\"dp critic\", logger=logger)\n    def compute_values(self, data: TensorDict) -> torch.Tensor:\n        self.critic_module.eval()\n        micro_batch_size = data[\"micro_batch_size\"]\n        select_keys = [\"responses\", \"input_ids\", \"attention_mask\", \"position_ids\"]\n        use_dynamic_bsz = data[\"use_dynamic_bsz\"]\n        has_multi_modal_inputs = \"multi_modal_inputs\" in data.keys()\n\n        if has_multi_modal_inputs:\n            num_micro_batches = data.batch_size[0] // micro_batch_size\n            select_keys.append(\"multi_modal_inputs\")\n  
          micro_batches = data.select(*select_keys).chunk(num_micro_batches)\n        elif use_dynamic_bsz:\n            # split using dynamic bsz\n            batch = data.select(*select_keys)\n            max_token_len = data[\"max_token_len\"] * self.ulysses_sequence_parallel_size\n            micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len)\n        else:\n            batch = data.select(*select_keys)\n            micro_batches = batch.split(micro_batch_size)\n\n        values_lst = []\n        for micro_batch in micro_batches:\n            with torch.no_grad():\n                values = self._forward_micro_batch(micro_batch)\n            values_lst.append(values)\n        values = torch.concat(values_lst, dim=0)\n\n        if use_dynamic_bsz:\n            indices = list(itertools.chain.from_iterable(indices))\n            assert len(indices) == values.size(0), f\"{len(indices)} vs. {values.size()}\"\n            revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)\n            values = values[revert_indices]\n\n        response_mask = data[\"response_mask\"]\n\n        values = values * response_mask  # Only action tokens have values\n        return values\n\n    # @GPUMemoryLogger(role=\"dp critic\", logger=logger)\n    def update_critic(self, data: TensorDict):\n        # make sure we are in training mode\n        self.critic_module.train()\n        metrics = {}\n        select_keys = [\"input_ids\", \"responses\", \"attention_mask\", \"position_ids\", \"values\", \"returns\", \"response_mask\"]\n\n        \n        has_multi_modal_inputs = \"multi_modal_inputs\" in data.keys()\n\n        # Split to make minibatch iterator for updating the actor\n        # See PPO paper for details. 
https://arxiv.org/abs/1707.06347\n        if has_multi_modal_inputs:\n            num_mini_batches = data.batch_size[0] // self.config.ppo_mini_batch_size\n            select_keys.append(\"multi_modal_inputs\")\n            dataloader = data.select(*select_keys).chunk(num_mini_batches)\n        else:\n            batch = data.select(*select_keys)\n            dataloader = batch.split(self.config.ppo_mini_batch_size)\n\n        for epoch in range(self.config.ppo_epochs):\n            for batch_idx, data in enumerate(dataloader):\n                # split batch into micro_batches\n                mini_batch = data\n                if has_multi_modal_inputs:\n                    num_micro_batches = mini_batch.batch_size[0] // self.config.ppo_micro_batch_size_per_gpu\n                    micro_batches = data.select(*select_keys).chunk(num_micro_batches)\n                elif self.config.use_dynamic_bsz:\n                    max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size\n                    micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)\n                else:\n                    micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)\n                self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu\n\n                self.critic_optimizer.zero_grad()\n\n                for data in micro_batches:\n                    # Support all devices\n                    data = data.to(get_device_id())  # critic device is cpu when using offload\n                    responses = data[\"responses\"]\n                    attention_mask = data[\"attention_mask\"]\n                    values = data[\"values\"]\n                    returns = data[\"returns\"]\n                    response_length = responses.size(1)\n                    response_mask = data[\"response_mask\"]\n\n                    vpreds = self._forward_micro_batch(data)\n\n                    # assert not torch.any(torch.isnan(vpreds)).item()\n\n                    vf_loss, vf_clipfrac = core_algos.compute_value_loss(\n                        vpreds=vpreds,\n                        values=values,\n                        returns=returns,\n                        response_mask=response_mask,\n                        cliprange_value=self.config.cliprange_value,\n                        loss_agg_mode=self.config.loss_agg_mode,\n                    )\n                    if self.config.use_dynamic_bsz:\n                        # relative to the dynamic bsz\n                        loss = vf_loss * (len(data) / self.config.ppo_mini_batch_size)\n                    else:\n                        loss = vf_loss / self.gradient_accumulation\n\n                    loss.backward()\n\n                    data = {\n                        \"critic/vf_loss\": vf_loss.detach().item(),\n                        \"critic/vf_clipfrac\": vf_clipfrac.detach().item(),\n                        \"critic/vpred_mean\": masked_mean(vpreds, response_mask).detach().item(),\n                    }\n\n                    append_to_dict(metrics, data)\n\n                grad_norm = self._optimizer_step()\n                data = {\"critic/grad_norm\": grad_norm.detach().item()}\n                append_to_dict(metrics, data)\n        self.critic_optimizer.zero_grad()\n        return metrics\n"
  },
  {
    "path": "siirl/engine/critic/megatron_critic.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nImplement a multiprocess PPOCritic\n\"\"\"\n\nimport itertools\nfrom functools import partial\nfrom typing import Iterable\nfrom loguru import logger\n\nimport torch\nimport torch.distributed\nfrom megatron.core import parallel_state as mpu\nfrom megatron.core.optimizer import DistributedOptimizer, OptimizerConfig\nfrom megatron.core.pipeline_parallel import get_forward_backward_func\nfrom omegaconf import OmegaConf\nfrom torch import nn\nfrom tensordict import TensorDict\nfrom siirl.dag_worker import core_algos\nfrom siirl.utils.debug import GPUMemoryLogger\nfrom siirl.utils.extras.device import get_device_id, get_torch_device\nfrom siirl.utils.megatron.pipeline_parallel import make_batch_generator\nfrom siirl.utils.extras.py_functional import append_to_dict\nfrom siirl.utils.model_utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches\nfrom siirl.utils.model_utils.torch_functional import broadcast_dict_tensor, masked_mean\nfrom siirl.engine.critic import BasePPOCritic\n\n\nclass MegatronPPOCritic(BasePPOCritic):\n    def __init__(\n        self,\n        config,\n        model_config,\n        hf_config,\n        tf_config,\n        critic_module: nn.ModuleList,\n        critic_optimizer: DistributedOptimizer,\n        critic_optimizer_config: OptimizerConfig,\n    ):\n        super().__init__(config=config)\n        self._validate_config(config)\n        self.model_config = model_config\n        self.hf_config = hf_config  # huggingface config\n        self.tf_config = tf_config  # mcore transformer config\n\n        self.critic_module = critic_module\n        self.critic_optimizer = critic_optimizer\n        self.critic_optimizer_config = critic_optimizer_config\n\n        # we create a separate nametuple for optimizer step so that global args won't affect it.\n        self.optimizer_step_args = OmegaConf.create(\n            {\n                \"skip_grad\": None,\n                \"overlap_dp_param_comm\": False,\n                \"overlap_dp_grad_comm\": False,\n                \"gradient_accumulation_steps\": 1,\n                \"sequence_parallel\": self.tf_config.sequence_parallel,\n                \"DDP_impl\": \"local\",\n                \"layernorm_allreduce_bucket_threshold\": 0,\n                \"pipeline_model_parallel_split_rank\": None,\n                \"reduce_grads_use_alltoall\": False,\n            }\n        )\n\n    def _validate_config(self, config) -> None:\n        \"\"\"Validate config options not implemented for Megatron backend\"\"\"\n        assert config.ulysses_sequence_parallel_size == 1\n        if config.shuffle:\n            assert config.data_loader_seed is not None, \"If shuffle dataloader, seed must be manually set\"\n        if config.megatron.tensor_model_parallel_size == 1:\n            print(\"[Warining] Because critic tp size == 1, set sp to False\")\n            config.megatron.sequence_parallel = 
False\n        self.config = config\n\n    @GPUMemoryLogger(\"megatron critic\", logger=logger)\n    def compute_values(self, data: TensorDict) -> TensorDict:\n        data.to(get_device_id())\n        responses = data[\"responses\"]\n        attention_mask = data[\"attention_mask\"]\n        use_dynamic_bsz = data[\"use_dynamic_bsz\"]\n        micro_batch_size = data[\"micro_batch_size\"]\n        max_token_len = data[\"max_token_len\"]\n        assert micro_batch_size is not None, \"micro batch size is needed for forward compute\"\n        if use_dynamic_bsz:\n            assert max_token_len is not None, \"max_token_len must be set when use_dynamic_bsz is True\"\n            max_token_len = max_token_len * self.config.megatron.context_parallel_size\n        response_length = responses.size(1)\n        with torch.no_grad():\n            output = self.forward_backward_batch(data=data, forward_only=True, use_dynamic_bsz=use_dynamic_bsz, micro_batch_size=micro_batch_size, max_token_len=max_token_len, mini_batch_size=None)\n            if mpu.is_pipeline_last_stage(ignore_virtual=True):\n                # only on last rank. It should be on every tp rank\n                values = [o[\"vpreds\"] for o in output[\"output\"]]  # (bs, seq_size, vocal_size)\n                values = torch.cat(values, dim=0).to(torch.float32)\n                if use_dynamic_bsz:\n                    indices = output[\"indices\"]\n                    indices = list(itertools.chain.from_iterable(indices))\n                    assert len(indices) == values.size(0), f\"{len(indices)} vs. {values.size()}\"\n                    revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)\n                    values = values[revert_indices]\n            else:\n                values = torch.empty_like(attention_mask, dtype=torch.float32)\n\n            # each tp ranks should contain the same value\n            values = values[:, -response_length - 1 : -1]  # Values are predicted at the ends of prefixes, e.g., the last prompt token\n            response_mask = data[\"response_mask\"]\n\n            values = values * response_mask  # Only action tokens have values\n            values = values.contiguous()\n\n            # sync among pp ranks\n            torch.distributed.broadcast(\n                tensor=values,\n                src=mpu.get_pipeline_model_parallel_last_rank(),\n                group=mpu.get_pipeline_model_parallel_group(),\n            )\n\n        # add empty cache after each compute\n        get_torch_device().empty_cache()\n\n        return values\n\n\n    def forward_backward_batch(self, data: TensorDict, forward_only=False, use_dynamic_bsz=False, micro_batch_size=None, max_token_len=None, mini_batch_size=None):\n        # broadcast from last pp rank to all other pp ranks\n        mini_batch = data\n        mini_batch.to(get_device_id())\n        mini_batch = mini_batch.contiguous()\n        broadcast_dict_tensor(mini_batch, src=mpu.get_pipeline_model_parallel_last_rank(), group=mpu.get_pipeline_model_parallel_group())\n        \n        # broadcast from cp rank 0 to all other cp ranks to ensure same data across CP group\n        cp_size = mpu.get_context_parallel_world_size()\n        if cp_size > 1:\n            cp_group = mpu.get_context_parallel_group()\n            cp_group_ranks = torch.distributed.get_process_group_ranks(cp_group)\n            src_rank = cp_group_ranks[0]  # cp_rank=0 in this group\n            broadcast_dict_tensor(\n                mini_batch,\n                
src=src_rank,\n                group=cp_group,\n            )\n        \n        # split into micro-batches\n        mini_batch[\"attention_mask\"] = mini_batch[\"attention_mask\"].to(bool)\n\n        indices = None\n        if use_dynamic_bsz:\n            assert max_token_len is not None, \"max_token_len must be set when use_dynamic_bsz is True\"\n            vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()\n            if vpp_size is not None and vpp_size > 1:\n                microbatch_group_size_per_vp_stage = self.tf_config.microbatch_group_size_per_vp_stage\n                micro_batches, indices = rearrange_micro_batches(batch=mini_batch, num_batches_divided_by=microbatch_group_size_per_vp_stage, max_token_len=max_token_len)\n                assert len(micro_batches) % self.tf_config.microbatch_group_size_per_vp_stage == 0, f\"micro_batches {micro_batches} must be divisible by microbatch_group_size_per_vp_stage {microbatch_group_size_per_vp_stage} for megatron backend\"\n            else:\n                micro_batches, indices = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)\n            total_seqlen = max_token_len\n        else:\n            assert micro_batch_size is not None, \"micro_batch_size is needed to be passed in when not using dynamic batch size\"\n            micro_batches = mini_batch.split(micro_batch_size)\n            seq_len = micro_batches[0][\"input_ids\"].shape[1]\n            total_seqlen = micro_batch_size * seq_len\n        n_micro_batch = len(micro_batches)\n\n        forward_backward_func = get_forward_backward_func()\n\n        def loss_func(output, data, meta_info):\n            nonlocal use_dynamic_bsz\n\n            if forward_only:\n                return torch.tensor(1.0, device=output.device), {\"vpreds\": output}\n\n            responses = data[\"responses\"]\n            attention_mask = data[\"attention_mask\"]\n            values = data[\"values\"]\n            returns = data[\"returns\"]\n            response_length = responses.size(1)\n            response_mask = data[\"response_mask\"]\n\n            cliprange_value = self.config.cliprange_value\n\n            vpreds = output  # (bs, sequence_length)\n            vpreds = vpreds[:, -response_length - 1 : -1]\n\n            vf_loss, vf_clipfrac = core_algos.compute_value_loss(\n                vpreds=vpreds,\n                values=values,\n                returns=returns,\n                response_mask=response_mask,\n                cliprange_value=cliprange_value,\n                loss_agg_mode=self.config.loss_agg_mode,\n            )\n\n            stats = {\n                \"critic/vf_loss\": vf_loss.detach().item(),\n                \"critic/vf_clipfrac\": vf_clipfrac.detach().item(),\n                \"critic/vpred_mean\": masked_mean(vpreds, response_mask).detach().item(),\n            }\n\n            return vf_loss, stats\n\n        def forward_step(batch_iter, model):\n            batch = next(batch_iter)\n            input_ids = batch[\"input_ids\"]\n            attention_mask = batch[\"attention_mask\"]\n            position_ids = batch[\"position_ids\"]\n            from siirl.models.mcore import get_mcore_forward_fn\n\n            forward_fn = get_mcore_forward_fn(self.hf_config)\n\n            output = forward_fn(\n                model,\n                input_ids,\n                attention_mask,\n                position_ids,\n                sequence_parallel=self.tf_config.sequence_parallel,\n                value_model=True,\n     
       )\n\n            return output, partial(loss_func, data=batch, meta_info={})\n\n        # batch should be a list of batches inside micro-batches\n        batch_generator = make_batch_generator(micro_batches, vpp_size=len(self.critic_module))\n\n        # TODO: we may use the new schedule instead\n        # for flash-attn: (seq_len, batch_size, hidden_size) = (mbs*seq_len, 1, hidden_size)\n        if mpu.get_pipeline_model_parallel_world_size() > 1:\n            losses_reduced = forward_backward_func(\n                forward_step_func=forward_step,\n                data_iterator=batch_generator,\n                model=self.critic_module,\n                num_microbatches=n_micro_batch,\n                seq_length=total_seqlen,  # no use when input_shapes was set\n                micro_batch_size=1,  # no use when input_shapes was set\n                forward_only=forward_only,\n            )\n        else:\n            losses_reduced = forward_backward_func(\n                forward_step_func=forward_step,\n                data_iterator=batch_generator,\n                model=self.critic_module,\n                num_microbatches=n_micro_batch,\n                seq_length=total_seqlen,  # in use for pp = 1\n                micro_batch_size=1,  # in use for pp = 1\n                forward_only=forward_only,\n            )\n        # loss_reduces contains the stats returned from loss_func\n        losses_reduced = {\"output\": losses_reduced}\n        if use_dynamic_bsz:\n            losses_reduced[\"indices\"] = indices\n        return losses_reduced\n\n    @GPUMemoryLogger(\"megatron critic\", logger=logger)\n    def update_critic(self, data: TensorDict):\n        metrics = {}\n        select_keys = [\"input_ids\", \"responses\", \"attention_mask\", \"position_ids\", \"values\", \"returns\", \"response_mask\"]\n        has_multi_modal_inputs = \"multi_modal_inputs\" in data.keys()\n\n        # Split to make minibatch iterator for updating the actor\n        # See PPO paper for details. 
https://arxiv.org/abs/1707.06347\n        if has_multi_modal_inputs:\n            num_mini_batches = data.batch_size[0] // self.config.ppo_mini_batch_size\n            select_keys.append(\"multi_modal_inputs\")\n            dataloader = data.select(*select_keys).chunk(num_mini_batches)\n        else:\n            batch = data.select(*select_keys)\n            dataloader = batch.split(self.config.ppo_mini_batch_size)\n\n        for epoch in range(self.config.ppo_epochs):\n            for batch_idx, data in enumerate(dataloader):\n                # data = data.batch.to(self.critic_module.device)\n                self.critic_optimizer.zero_grad()\n                # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm\n                for chunk in self.critic_module:\n                    chunk.zero_grad_buffer()\n                micro_batch_size = self.config.ppo_micro_batch_size_per_gpu\n                max_token_len = None\n                if self.config.use_dynamic_bsz:\n                    max_token_len = self.config.ppo_max_token_len_per_gpu * self.config.megatron.context_parallel_size\n                metric_micro_batch = self.forward_backward_batch(data, forward_only=False, use_dynamic_bsz=self.config.use_dynamic_bsz, micro_batch_size=micro_batch_size, max_token_len=max_token_len, mini_batch_size=self.config.ppo_mini_batch_size)\n                metric_micro_batch = metric_micro_batch[\"output\"]\n                update_successful, grad_norm, num_zeros_in_grad = self.critic_optimizer.step()\n                learning_rate = self.critic_optimizer.param_groups[-1][\"lr\"]\n                data = {\"critic/grad_norm\": grad_norm, \"critic/lr\": learning_rate}\n                append_to_dict(metrics, data)\n\n                if update_successful:\n                    # allgather already execute in optimizer.step in new megatron\n                    pass\n                else:\n                    raise NotImplementedError\n\n                for metric in metric_micro_batch:\n                    append_to_dict(metrics, metric)  # append the metric from this micro-batch to global metrics.\n\n        # add empty cache after each compute\n        get_torch_device().empty_cache()\n        return metrics\n"
  },
  {
    "path": "siirl/engine/fsdp_workers.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThe main entry point to run the PPO algorithm\n\"\"\"\n\nimport json\nimport os\nimport warnings\nfrom dataclasses import asdict\nfrom typing import Union, Optional\n\nimport psutil\nimport torch\nimport torch.distributed\nfrom codetiming import Timer\nfrom loguru import logger\nfrom omegaconf import DictConfig\nfrom peft import LoraConfig, TaskType, get_peft_model\nfrom safetensors.torch import save_file\nfrom torch.distributed import ProcessGroup, init_device_mesh\nfrom torch.distributed.device_mesh import DeviceMesh\nfrom tensordict import TensorDict\n# Handle different tensordict versions - NonTensorData location varies\ntry:\n    from tensordict import NonTensorData\nexcept ImportError:\n    from tensordict.tensorclass import NonTensorData\nimport siirl.utils.model_utils.torch_functional as F\nfrom typing import Any, Dict, List, Optional, Union, Set\nfrom siirl.models.loader import load_tokenizer\nfrom siirl.engine.base_worker import Worker\nfrom siirl.execution.scheduler.enums import Role\nfrom siirl.execution.scheduler.enums import Role\nfrom siirl.utils.model_utils.activation_offload import enable_activation_offloading\nfrom siirl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager\nfrom siirl.utils.extras.device import get_device_id, get_device_name, get_nccl_backend, get_torch_device, is_cuda_available, is_npu_available\nfrom siirl.utils.model_utils.flops_counter import FlopsCounter\nfrom siirl.utils.extras.fs import copy_to_local\nfrom siirl.utils.model_utils.fsdp_utils import (\n    CPUOffloadPolicy,\n    MixedPrecisionPolicy,\n    apply_fsdp2,\n    fsdp2_load_full_state_dict,\n    fsdp_version,\n    get_fsdp_wrap_policy,\n    get_init_weight_context_manager,\n    init_fn,\n    layered_summon_lora_params,\n    load_fsdp_model_to_gpu,\n    load_fsdp_optimizer,\n    offload_fsdp_model_to_cpu,\n    offload_fsdp_optimizer,\n)\nfrom siirl.utils.extras.import_utils import import_external_libs\nfrom siirl.utils.model_utils.model import compute_position_id_with_mask\nfrom siirl.utils.extras.py_functional import convert_to_regular_types\nfrom siirl.params.model_args import ActorRolloutRefArguments, CriticArguments, FSDPArguments, OptimizerArguments, RewardModelArguments\nfrom siirl.engine.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager\n\ndevice_name = get_device_name()\n\n\ndef create_device_mesh_from_group(\n    process_group: ProcessGroup,\n    fsdp_size: int = 1,\n    sp_size: int = 1,\n) -> DeviceMesh:\n    \"\"\"Creates a DeviceMesh from a process group for specific parallel strategies.\n\n    This function configures a DeviceMesh based on the provided parallelism sizes.\n    It supports three mutually exclusive parallelism configurations:\n    1.  Data Parallelism + Sequence Parallelism ([dp, sp]): Activated when `sp_size` > 1.\n    2.  
Fully Sharded Data Parallelism ([fsdp]): Activated when `fsdp_size` is negative or at least the group size.\n    3.  Distributed Data Parallelism + FSDP ([ddp, fsdp]): Activated when `fsdp_size` is positive and smaller than the group size.\n\n    Note:\n        `sp_size > 1` and `fsdp_size > 1` cannot be used simultaneously.\n\n    Args:\n        process_group (ProcessGroup): The base process group for the mesh.\n        fsdp_size (int): The size of the FSDP dimension. Activates FSDP modes.\n                         Defaults to 1.\n        sp_size (int): The size of the Sequence Parallel dimension. Activates [dp, sp]\n                       mode if > 1. Defaults to 1.\n\n    Returns:\n        DeviceMesh: A configured DeviceMesh object for the specified topology.\n\n    Raises:\n        ValueError: If inputs are invalid, parallelism strategies are mixed,\n                    or the group size is not compatible with the requested\n                    parallelism dimensions.\n    \"\"\"\n    if process_group is None:\n        raise ValueError(\"`process_group` cannot be None.\")\n\n    if sp_size > 1 and fsdp_size > 1:\n        raise ValueError(\"Sequence Parallelism (sp_size > 1) and FSDP (fsdp_size > 1) are mutually exclusive and cannot be activated simultaneously.\")\n\n    import torch.distributed\n\n    device_type = get_device_name()\n    group_size = torch.distributed.get_world_size(group=process_group)\n    ranks_in_group = torch.distributed.get_process_group_ranks(process_group)\n\n    # --- 2. [dp, sp] Mode ---\n    if sp_size > 1:\n        if group_size % sp_size != 0:\n            raise ValueError(f\"For [dp, sp] mode, the process group size ({group_size}) must be divisible by sp_size ({sp_size}).\")\n        dp_size = group_size // sp_size\n        mesh_shape = (dp_size, sp_size)\n        mesh_dim_names = (\"dp\", \"sp\")\n        logger.info(f\"Creating [dp, sp] DeviceMesh with shape {mesh_shape}.\")\n\n        rank_mesh = torch.tensor(ranks_in_group, dtype=torch.long).view(mesh_shape)\n        return DeviceMesh(device_type, rank_mesh, mesh_dim_names=mesh_dim_names)\n\n    # --- 3. 
FSDP / DDP Modes ---\n    if fsdp_size < 0 or fsdp_size >= group_size:\n        # Pure FSDP (equivalent to DDP over the whole group).\n        # This creates a 1D mesh representing a single shard group over all ranks.\n        logger.info(\"Creating pure [fsdp] DeviceMesh from the process group.\")\n        # mesh_tensor = torch.tensor(ranks_in_group)\n        return DeviceMesh.from_group(group=process_group, device_type=device_type, mesh=ranks_in_group, mesh_dim_names=(\"fsdp\",))\n\n    # [ddp, fsdp] mode for fsdp_size > 1\n    if group_size % fsdp_size != 0:\n        raise ValueError(f\"The process group size ({group_size}) must be divisible by fsdp_size ({fsdp_size}).\")\n\n    # [ddp, fsdp] mode for fsdp_size > 1\n    ddp_size = group_size // fsdp_size\n    mesh_shape = (ddp_size, fsdp_size)\n    mesh_dim_names = (\"ddp\", \"fsdp\")\n    logger.info(f\"Creating [ddp, fsdp] DeviceMesh with shape {mesh_shape}.\")\n\n    rank_mesh = torch.tensor(ranks_in_group, dtype=torch.long).view(mesh_shape)\n    # TODO: support 2D process group(List)\n    return DeviceMesh(device_type=device_type, mesh=rank_mesh, mesh_dim_names=mesh_dim_names)\n\n\ndef get_sharding_strategy(device_mesh):\n    from torch.distributed.fsdp import ShardingStrategy\n\n    if device_mesh.ndim == 1:\n        sharding_strategy = ShardingStrategy.FULL_SHARD\n    elif device_mesh.ndim == 2:\n        sharding_strategy = ShardingStrategy.HYBRID_SHARD\n    else:\n        raise NotImplementedError(f\"Get device mesh ndim={device_mesh.ndim}, but only support 1 or 2\")\n    return sharding_strategy\n\n\nclass ActorRolloutRefWorker(Worker):\n    \"\"\"\n    This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy\n    or a hybrid engine based on the config.rollout\n    \"\"\"\n\n    def __init__(self, config: ActorRolloutRefArguments, role: str, process_group: ProcessGroup):\n        super().__init__()\n        self.config = config\n\n        import torch.distributed\n\n        if not torch.distributed.is_initialized():\n            rank = int(os.environ.get(\"RANK\", 0))\n            world_size = int(os.environ.get(\"WORLD_SIZE\", 1))\n            torch.distributed.init_process_group(backend=f\"cpu:gloo,{get_device_name()}:{get_nccl_backend()}\", rank=rank, world_size=world_size)\n\n        self.group_world_size = torch.distributed.get_world_size(group=process_group)\n        # build device mesh for FSDP\n        # TODO(sgm): support FSDP hybrid shard for larger model\n        self.device_mesh = create_device_mesh_from_group(process_group=process_group, fsdp_size=self.config.actor.fsdp_config.fsdp_size)\n\n        # build device mesh for Ulysses Sequence Parallel\n        self.ulysses_device_mesh = None\n        self.ulysses_sequence_parallel_size = self.config.actor.ulysses_sequence_parallel_size\n        if self.ulysses_sequence_parallel_size > 1:\n            self.ulysses_device_mesh = create_device_mesh_from_group(process_group=process_group, sp_size=self.ulysses_sequence_parallel_size)\n\n        self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)\n        self._lora_rank = self.config.model.lora_rank\n        self._is_lora = self._lora_rank > 0\n\n        self.role = role\n        assert self.role in [\"actor\", \"rollout\", \"ref\", \"actor_rollout\", \"actor_rollout_ref\"]\n\n        self._is_actor = self.role in [\"actor\", \"actor_rollout\", \"actor_rollout_ref\"]\n        self._is_rollout = self.role in [\"rollout\", \"actor_rollout\", 
\"actor_rollout_ref\"]\n        self._is_ref = self.role in [\"ref\", \"actor_rollout_ref\"]\n\n        self._is_offload_param = False\n        self._is_offload_optimizer = False\n        if self._is_actor:\n            self._is_offload_param = self.config.actor.fsdp_config.param_offload\n            self._is_offload_optimizer = self.config.actor.fsdp_config.optimizer_offload\n        elif self._is_ref:\n            # TODO: it seems that manual offload is slowly than FSDP offload\n            self._is_offload_param = self.config.ref.fsdp_config.param_offload\n\n        # normalize config\n        if self._is_actor:\n            self.config.actor.ppo_mini_batch_size *= self.config.rollout.n\n            self.config.actor.ppo_mini_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size\n            assert self.config.actor.ppo_mini_batch_size > 0, f\"ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than 0 after normalization\"\n            # micro bsz\n            if self.config.actor.ppo_micro_batch_size is not None:\n                self.config.actor.ppo_micro_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size\n                self.config.actor.ppo_micro_batch_size_per_gpu = self.config.actor.ppo_micro_batch_size\n\n            if self.config.actor.ppo_micro_batch_size_per_gpu is not None:\n                assert self.config.actor.ppo_mini_batch_size % self.config.actor.ppo_micro_batch_size_per_gpu == 0, \\\n                    f\"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be divisible by \" \\\n                    f\"ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}\"\n                assert self.config.actor.ppo_mini_batch_size // self.config.actor.ppo_micro_batch_size_per_gpu > 0, \\\n                    f\"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than \" \\\n                    f\"ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}\"\n\n        # normalize rollout config\n        if self._is_rollout and self.config.rollout.log_prob_micro_batch_size is not None:\n            self.config.rollout.log_prob_micro_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size\n            self.config.rollout.log_prob_micro_batch_size_per_gpu = self.config.rollout.log_prob_micro_batch_size\n        # normalize ref config\n        if self._is_ref and self.config.ref.log_prob_micro_batch_size is not None:\n            self.config.ref.log_prob_micro_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size\n            self.config.ref.log_prob_micro_batch_size_per_gpu = self.config.ref.log_prob_micro_batch_size\n\n    def _build_model_optimizer(\n        self,\n        model_path: str,\n        fsdp_config: FSDPArguments,\n        optim_config: Optional[OptimizerArguments],\n        override_model_config: DictConfig,\n        use_remove_padding: bool = False,\n        use_fused_kernels: bool = False,\n        enable_gradient_checkpointing: bool = False,\n        trust_remote_code: bool = False,\n        use_liger: bool = False,\n        role: Role = Role.Actor,\n        enable_activation_offload: bool = False,\n    ):\n        \"\"\"\n        Build model and optimizer (Refactored version).\n        \n        This method orchestrates the model building process through 4 main steps:\n            1. 
Prepare and load model from pretrained checkpoint\n            2. Apply model modifications (LoRA, gradient checkpointing, etc.)\n            3. Wrap model with FSDP\n            4. Create optimizer and learning rate scheduler (Actor only)\n        \n        Args:\n            model_path: Path to model checkpoint\n            fsdp_config: FSDP configuration\n            optim_config: Optimizer configuration (optional)\n            override_model_config: Config overrides\n            use_remove_padding: Whether to use remove padding\n            use_fused_kernels: Whether to use fused kernels\n            enable_gradient_checkpointing: Whether to enable gradient checkpointing\n            trust_remote_code: Whether to trust remote code\n            use_liger: Whether to apply Liger kernel\n            role: Role (Actor or RefPolicy)\n            enable_activation_offload: Whether to enable activation offload\n        \n        Returns:\n            Tuple of (model_fsdp, optimizer, lr_scheduler, model_config)\n        \"\"\"\n        from siirl.utils.model_utils.model import print_model_size\n        \n        assert role in [Role.Actor, Role.RefPolicy]\n        \n        # Step 1: Prepare and load model\n        actor_module, actor_model_config, torch_dtype = self._prepare_and_load_model(\n            model_path=model_path,\n            fsdp_config=fsdp_config,\n            override_model_config=override_model_config,\n            trust_remote_code=trust_remote_code,\n            role=role,\n        )\n        \n        # Step 2: Apply model modifications\n        actor_module = self._apply_model_modifications(\n            model=actor_module,\n            use_liger=use_liger,\n            use_remove_padding=use_remove_padding,\n            use_fused_kernels=use_fused_kernels,\n            enable_gradient_checkpointing=enable_gradient_checkpointing,\n            torch_dtype=torch_dtype,\n        )\n        \n        torch.distributed.barrier()\n        if self.rank == 0:\n            print_model_size(actor_module)\n        \n        # Step 3: Wrap model with FSDP\n        actor_module_fsdp = self._setup_fsdp_wrapper(\n            model=actor_module,\n            fsdp_config=fsdp_config,\n            role=role,\n            enable_activation_offload=enable_activation_offload,\n            enable_gradient_checkpointing=enable_gradient_checkpointing,\n        )\n        \n        # Step 4: Create optimizer and scheduler\n        actor_optimizer, actor_lr_scheduler = self._create_optimizer_and_scheduler(\n            model=actor_module_fsdp,\n            optim_config=optim_config,\n            role=role,\n        )\n        \n        return actor_module_fsdp, actor_optimizer, actor_lr_scheduler, actor_model_config\n\n    # ==================================================================================\n    # Legacy implementation (kept for rollback, remove after verification)\n    # ==================================================================================\n    def _build_model_optimizer_legacy(\n        self,\n        model_path: str,\n        fsdp_config: FSDPArguments,\n        optim_config: Optional[OptimizerArguments],\n        override_model_config: DictConfig,\n        use_remove_padding: bool = False,\n        use_fused_kernels: bool = False,\n        enable_gradient_checkpointing: bool = False,\n        trust_remote_code: bool = False,\n        use_liger: bool = False,\n        role: Role = Role.Actor,\n        enable_activation_offload: bool = False,\n    ):\n        from 
torch import optim\n        from torch.distributed.fsdp import CPUOffload, MixedPrecision\n        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP\n        from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoImageProcessor, AutoProcessor\n\n        from siirl.utils.model_utils.model import get_generation_config, print_model_size, update_model_config\n        from siirl.utils.model_utils.torch_dtypes import PrecisionType\n\n        assert role in [Role.Actor, Role.RefPolicy]\n\n        local_path = model_path\n\n        if self.config.model.model_type == \"embodied\":\n            if self.config.embodied.embodied_type == \"openvla-oft\":\n                from siirl.models.embodied.openvla_oft.configuration_prismatic import OpenVLAConfig\n                from siirl.models.embodied.openvla_oft.modeling_prismatic import OpenVLAForActionPrediction\n                from siirl.models.embodied.openvla_oft.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor\n                \n                AutoConfig.register(\"openvla\", OpenVLAConfig)\n                AutoImageProcessor.register(OpenVLAConfig, PrismaticImageProcessor)\n                AutoProcessor.register(OpenVLAConfig, PrismaticProcessor)\n                AutoModelForVision2Seq.register(OpenVLAConfig, OpenVLAForActionPrediction)\n                if self.rank == 0:\n                    try:\n                        from siirl.utils.embodied.openvla_utils import update_auto_map, check_model_logic_mismatch\n                        logger.info(f\"[rank-{self.rank}] Updating auto_map for OpenVLA-OFT at {local_path}\")\n                        update_auto_map(local_path)\n                        check_model_logic_mismatch(local_path)\n                        logger.info(f\"[rank-{self.rank}] Successfully updated auto_map for OpenVLA-OFT\")\n                    except Exception as e:\n                        logger.error(f\"[rank-{self.rank}] Failed to update auto_map for OpenVLA-OFT: {e}\")\n                        raise\n                torch.distributed.barrier()\n            elif self.config.embodied.embodied_type == \"openvla\":\n                from siirl.models.embodied.openvla.configuration_prismatic import OpenVLAConfig\n                from siirl.models.embodied.openvla.modeling_prismatic import OpenVLAForActionPrediction\n                from siirl.models.embodied.openvla.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor\n                \n                AutoConfig.register(\"openvla\", OpenVLAConfig)\n                AutoImageProcessor.register(OpenVLAConfig, PrismaticImageProcessor)\n                AutoProcessor.register(OpenVLAConfig, PrismaticProcessor)\n                AutoModelForVision2Seq.register(OpenVLAConfig, OpenVLAForActionPrediction)\n                if self.rank == 0:\n                    try:\n                        from siirl.utils.embodied.openvla_utils import update_auto_map, check_model_logic_mismatch\n                        logger.info(f\"[rank-{self.rank}] Updating auto_map for OpenVLA at {local_path}\")\n                        update_auto_map(local_path)\n                        check_model_logic_mismatch(local_path)\n                        logger.info(f\"[rank-{self.rank}] Successfully updated auto_map for OpenVLA\")\n                    except Exception as e:\n                        logger.error(f\"[rank-{self.rank}] Failed to update auto_map for OpenVLA: {e}\")\n                        raise\n                
torch.distributed.barrier()\n            else:\n                raise ValueError(f\"Invalid vla type: {self.config.embodied.embodied_type}\")\n\n        torch_dtype = fsdp_config.model_dtype\n        if torch_dtype is None:\n            torch_dtype = torch.float32 if self._is_actor else torch.bfloat16\n        else:\n            torch_dtype = PrecisionType.to_dtype(torch_dtype)\n\n        # override model kwargs\n        if self.config.model.model_type == \"embodied\" and self.config.embodied.embodied_type == \"openvla-oft\":\n            actor_model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)\n        else:\n            actor_model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code, attn_implementation=\"flash_attention_2\")\n        if self._is_ref:\n            self.flops_counter = FlopsCounter(actor_model_config, forward_only=True)\n        # patch for kimi-vl\n        if getattr(actor_model_config, \"model_type\", None) == \"kimi_vl\":\n            actor_model_config.text_config.topk_method = \"greedy\"\n\n        self.generation_config = get_generation_config(local_path, trust_remote_code=trust_remote_code)\n\n        override_config_kwargs = {\n            \"bos_token_id\": self.tokenizer.bos_token_id,\n            \"eos_token_id\": self.tokenizer.eos_token_id,\n            \"pad_token_id\": self.tokenizer.pad_token_id,\n        }\n        override_config_kwargs.update(override_model_config)\n        update_model_config(actor_model_config, override_config_kwargs=override_config_kwargs)\n        # if self.rank == 0:\n        #     logger.info(f\"Model config after override: {actor_model_config}\")\n\n        # NOTE(fix me): tie_word_embedding causes meta_tensor init to hang\n        init_context = get_init_weight_context_manager(use_meta_tensor=not actor_model_config.tie_word_embeddings, mesh=self.device_mesh)\n\n        with init_context(), warnings.catch_warnings():\n            warnings.simplefilter(\"ignore\")\n            is_embodied_model = self.config.model.model_type == \"embodied\"\n            if type(actor_model_config) in AutoModelForVision2Seq._model_mapping.keys() or is_embodied_model:\n                actor_module_class = AutoModelForVision2Seq\n            elif type(actor_model_config).__name__ in \"configuration_internvl_chat.InternVLChatConfig\":\n                from siirl.models.transformers.internvl_chat import InternVLChatModel\n\n                actor_module_class = InternVLChatModel\n                logger.info(\"Choose InternVLChatModel for internvl\")\n            else:\n                actor_module_class = AutoModelForCausalLM\n\n            \n            \n            if is_embodied_model and self.config.embodied.embodied_type == \"openvla-oft\":\n                # OpenVLA-OFT: No flash_attention_2, requires additional setup\n                logger.info(\"Loading OpenVLA-OFT model (without flash_attention_2)\")\n                actor_module = actor_module_class.from_pretrained(\n                    pretrained_model_name_or_path=local_path,\n                    torch_dtype=torch_dtype,\n                    config=actor_model_config,\n                    trust_remote_code=trust_remote_code,\n                )\n                \n                # Set the number of images in input for multi-camera support\n                if hasattr(actor_module, 'vision_backbone'):\n                    num_images = getattr(self.config.embodied, 'num_images_in_input', 1)\n                    
actor_module.vision_backbone.set_num_images_in_input(num_images)\n                    logger.info(f\"Set vision_backbone.num_images_in_input = {num_images}\")\n                \n                # Load dataset statistics for action normalization\n                dataset_statistics_path = os.path.join(local_path, \"dataset_statistics.json\")\n                if os.path.isfile(dataset_statistics_path):\n                    with open(dataset_statistics_path, \"r\") as f:\n                        norm_stats = json.load(f)\n                    actor_module.norm_stats = norm_stats\n                    logger.info(f\"Loaded dataset_statistics.json with {len(norm_stats)} task(s)\")\n                else:\n                    logger.warning(\n                        \"WARNING: No dataset_statistics.json file found for OpenVLA-OFT checkpoint.\\n\"\n                        \"You can ignore this if loading the base VLA checkpoint (not fine-tuned).\\n\"\n                        \"Otherwise, you may encounter errors when calling predict_action() due to missing unnorm_key.\"\n                    )\n                    \n            elif is_embodied_model and self.config.embodied.embodied_type == \"openvla\":\n                # OpenVLA: Use flash_attention_2 for efficiency\n                logger.info(\"Loading OpenVLA model (with flash_attention_2)\")\n                actor_module = actor_module_class.from_pretrained(\n                    pretrained_model_name_or_path=local_path,\n                    torch_dtype=torch_dtype,\n                    attn_implementation=\"flash_attention_2\",\n                    config=actor_model_config,\n                    trust_remote_code=trust_remote_code,\n                )\n            else:\n                # Default loading for non-VLA models\n                actor_module = actor_module_class.from_pretrained(\n                    pretrained_model_name_or_path=local_path,\n                    torch_dtype=torch_dtype,\n                    config=actor_model_config,\n                    trust_remote_code=trust_remote_code,\n                )\n\n            # Apply Liger kernel to the model if use_liger is set to True\n            if use_liger:\n                from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance\n                _apply_liger_kernel_to_instance(model=actor_module)\n            \n            if not is_embodied_model:\n                from siirl.models.transformers.monkey_patch import apply_monkey_patch\n                apply_monkey_patch(\n                    model=actor_module,\n                    use_remove_padding=use_remove_padding,\n                    ulysses_sp_size=self.ulysses_sequence_parallel_size,\n                    use_fused_kernels=use_fused_kernels,\n                )\n\n            # some parameters may not in torch_dtype. 
TODO(zhangchi.usc1992) remove this after we switch to fsdp2\n            actor_module.to(torch_dtype)\n\n            if enable_gradient_checkpointing:\n                actor_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={\"use_reentrant\": False})\n            if self._is_lora:\n                logger.info(\"Applying LoRA to actor module\")\n                actor_module.enable_input_require_grads()\n                # Convert config to regular Python types before creating PEFT model\n                lora_config = {\"task_type\": TaskType.CAUSAL_LM, \n                               \"r\": self.config.model.lora_rank, \n                               \"lora_alpha\": self.config.model.lora_alpha, \n                               \"target_modules\": convert_to_regular_types(self.config.model.target_modules), \"bias\": \"none\"}\n                actor_module = get_peft_model(actor_module, LoraConfig(**lora_config))\n                \n        torch.distributed.barrier()\n\n        if self.rank == 0:\n            print_model_size(actor_module)\n\n        # We wrap FSDP for rollout as well\n        mixed_precision_config = fsdp_config.mixed_precision\n        if mixed_precision_config is not None:\n            param_dtype = PrecisionType.to_dtype(mixed_precision_config.param_dtype)\n            reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.reduce_dtype)\n            buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.buffer_dtype)\n        else:\n            param_dtype = torch.bfloat16\n            reduce_dtype = torch.float32\n            buffer_dtype = torch.float32\n\n        mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)\n\n        auto_wrap_policy = get_fsdp_wrap_policy(module=actor_module, config=fsdp_config.wrap_policy, is_lora=self.config.model.lora_rank > 0)\n\n        if self._is_rollout and self.config.rollout.name == \"hf\" and not is_embodied_model:\n            # TODO(zhangchi.usc1992, shengguangming) fix me. 
Current, auto_wrap_policy causes HFRollout to hang in Gemma\n            auto_wrap_policy = None\n\n        if self.rank == 0:\n            logger.info(f\"wrap_policy: {auto_wrap_policy}\")\n\n        fsdp_mesh = self.device_mesh\n        sharding_strategy = get_sharding_strategy(fsdp_mesh)\n\n        # TODO: add transformer policy\n        # We force reference policy to use CPUOffload to save memory.\n        # We force turn off CPUOffload for actor because it causes incorrect results when using grad accumulation\n        cpu_offload = None if role == Role.Actor else CPUOffload(offload_params=True)\n        fsdp_strategy = self.config.actor.strategy\n        if fsdp_strategy == \"fsdp\":\n            actor_module_fsdp = FSDP(\n                actor_module,\n                cpu_offload=cpu_offload,\n                param_init_fn=init_fn,\n                use_orig_params=False,\n                auto_wrap_policy=auto_wrap_policy,\n                device_id=get_device_id(),\n                sharding_strategy=sharding_strategy,  # zero3\n                mixed_precision=mixed_precision,\n                sync_module_states=True,\n                device_mesh=self.device_mesh,\n                forward_prefetch=False,\n            )\n        elif fsdp_strategy == \"fsdp2\":\n            assert CPUOffloadPolicy is not None, \"PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)\"\n            mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True)\n            if role == Role.Actor and fsdp_config.offload_policy:\n                cpu_offload = CPUOffloadPolicy(pin_memory=True)\n                self._is_offload_param = False\n                self._is_offload_optimizer = False\n            else:\n                cpu_offload = None if role == Role.Actor else CPUOffloadPolicy(pin_memory=True)\n\n            fsdp_kwargs = {\n                \"mesh\": fsdp_mesh,\n                \"mp_policy\": mp_policy,\n                \"offload_policy\": cpu_offload,\n                \"reshard_after_forward\": fsdp_config.reshard_after_forward,\n            }\n            full_state = actor_module.state_dict()\n            apply_fsdp2(actor_module, fsdp_kwargs, fsdp_config)\n            fsdp2_load_full_state_dict(actor_module, full_state, fsdp_mesh, cpu_offload)\n            actor_module_fsdp = actor_module\n        else:\n            raise NotImplementedError(f\"not implement {fsdp_strategy}\")\n\n        if enable_activation_offload:\n            enable_activation_offloading(actor_module_fsdp, fsdp_strategy, enable_gradient_checkpointing)\n\n        # TODO: add more optimizer args into config\n        if role == Role.Actor and optim_config is not None:\n            from siirl.utils.model_utils.torch_functional import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup\n\n            actor_optimizer = optim.AdamW(\n                actor_module_fsdp.parameters(),\n                lr=optim_config.lr,\n                betas=optim_config.betas,\n                weight_decay=optim_config.weight_decay,\n            )\n\n            total_steps = optim_config.total_training_steps\n            num_warmup_steps = int(optim_config.lr_warmup_steps)\n            warmup_style = optim_config.warmup_style\n            min_lr_ratio = optim_config.min_lr_ratio if optim_config.min_lr_ratio else 0.0\n            num_cycles = optim_config.num_cycles\n            if num_warmup_steps < 0:\n                num_warmup_steps_ratio = 
optim_config.lr_warmup_steps_ratio\n                num_warmup_steps = int(num_warmup_steps_ratio * total_steps)\n\n            if self.rank == 0:\n                logger.info(f\"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}\")\n\n            if warmup_style == \"constant\":\n                actor_lr_scheduler = get_constant_schedule_with_warmup(optimizer=actor_optimizer, num_warmup_steps=num_warmup_steps)\n            elif warmup_style == \"cosine\":\n                actor_lr_scheduler = get_cosine_schedule_with_warmup(optimizer=actor_optimizer, num_warmup_steps=num_warmup_steps, \n                                                                     num_training_steps=total_steps, min_lr_ratio=min_lr_ratio, \n                                                                     num_cycles=num_cycles)\n            else:\n                raise NotImplementedError(f\"Warmup style {warmup_style} is not supported\")\n        else:\n            actor_optimizer = None\n            actor_lr_scheduler = None\n\n        return actor_module_fsdp, actor_optimizer, actor_lr_scheduler, actor_model_config\n\n    def _prepare_and_load_model(\n        self,\n        model_path: str,\n        fsdp_config: FSDPArguments,\n        override_model_config: DictConfig,\n        trust_remote_code: bool,\n        role: Role,\n    ):\n        \"\"\"\n        Prepare configuration and load model from pretrained checkpoint.\n        \n        Steps:\n            1. Register embodied model classes if needed\n            2. Determine torch dtype\n            3. Load and configure model config\n            4. Load model from pretrained\n            5. Apply model-specific setup (e.g., OpenVLA-OFT)\n        \n        Args:\n            model_path: Path to model checkpoint\n            fsdp_config: FSDP configuration\n            override_model_config: Config overrides\n            trust_remote_code: Whether to trust remote code\n            role: Role (Actor or RefPolicy)\n        \n        Returns:\n            Tuple of (model, model_config, torch_dtype)\n        \"\"\"\n        from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq\n        from siirl.utils.model_utils.model import get_generation_config, update_model_config\n        from siirl.utils.model_utils.torch_dtypes import PrecisionType\n        \n        is_embodied = self.config.model.model_type == \"embodied\"\n        \n        # Register embodied model classes\n        if is_embodied:\n            self._register_embodied_model(model_path, trust_remote_code)\n        \n        # Determine torch dtype\n        if fsdp_config.model_dtype is None:\n            torch_dtype = torch.float32 if self._is_actor else torch.bfloat16\n        else:\n            torch_dtype = PrecisionType.to_dtype(fsdp_config.model_dtype)\n        \n        # Load model config with appropriate attention implementation\n        embodied_type = getattr(self.config.embodied, 'embodied_type', None) if is_embodied else None\n        use_flash_attn = not (is_embodied and embodied_type == \"openvla-oft\")\n        \n        config_kwargs = {\"trust_remote_code\": trust_remote_code}\n        if use_flash_attn:\n            config_kwargs[\"attn_implementation\"] = \"flash_attention_2\"\n        \n        model_config = AutoConfig.from_pretrained(model_path, **config_kwargs)\n        \n        # Initialize flops counter for reference policy\n        if self._is_ref:\n            self.flops_counter = FlopsCounter(model_config, forward_only=True)\n      
  \n        # Apply model-specific patches\n        if getattr(model_config, \"model_type\", None) == \"kimi_vl\":\n            model_config.text_config.topk_method = \"greedy\"\n        \n        # Load and update generation config\n        self.generation_config = get_generation_config(model_path, trust_remote_code)\n        \n        override_kwargs = {\n            \"bos_token_id\": self.tokenizer.bos_token_id,\n            \"eos_token_id\": self.tokenizer.eos_token_id,\n            \"pad_token_id\": self.tokenizer.pad_token_id,\n        }\n        override_kwargs.update(override_model_config)\n        update_model_config(model_config, override_config_kwargs=override_kwargs)\n        \n        # Load model with appropriate context manager\n        init_context = get_init_weight_context_manager(\n            use_meta_tensor=not model_config.tie_word_embeddings,\n            mesh=self.device_mesh\n        )\n        \n        with init_context(), warnings.catch_warnings():\n            warnings.simplefilter(\"ignore\")\n            \n            # Determine model class\n            model_class = self._get_model_class(model_config, is_embodied)\n            \n            # Load model based on type\n            if is_embodied and embodied_type == \"openvla-oft\":\n                logger.info(\"Loading OpenVLA-OFT model (without flash_attention_2)\")\n                model = model_class.from_pretrained(\n                    model_path,\n                    torch_dtype=torch_dtype,\n                    config=model_config,\n                    trust_remote_code=trust_remote_code,\n                )\n                self._setup_openvla_oft_model(model, model_path)\n                \n            elif is_embodied and embodied_type == \"openvla\":\n                logger.info(\"Loading OpenVLA model (with flash_attention_2)\")\n                model = model_class.from_pretrained(\n                    model_path,\n                    torch_dtype=torch_dtype,\n                    attn_implementation=\"flash_attention_2\",\n                    config=model_config,\n                    trust_remote_code=trust_remote_code,\n                )\n            else:\n                model = model_class.from_pretrained(\n                    model_path,\n                    torch_dtype=torch_dtype,\n                    config=model_config,\n                    trust_remote_code=trust_remote_code,\n                )\n        \n        return model, model_config, torch_dtype\n\n    def _get_model_class(self, model_config, is_embodied: bool):\n        \"\"\"Determine the appropriate model class based on config.\"\"\"\n        from transformers import AutoModelForCausalLM, AutoModelForVision2Seq\n        \n        if type(model_config) in AutoModelForVision2Seq._model_mapping.keys() or is_embodied:\n            return AutoModelForVision2Seq\n        elif type(model_config).__name__ == \"configuration_internvl_chat.InternVLChatConfig\":\n            from siirl.models.transformers.internvl_chat import InternVLChatModel\n            logger.info(\"Using InternVLChatModel for internvl\")\n            return InternVLChatModel\n        else:\n            return AutoModelForCausalLM\n\n    def _register_embodied_model(self, model_path: str, trust_remote_code: bool):\n        \"\"\"Register embodied model classes to transformers registry.\"\"\"\n        from transformers import AutoConfig, AutoModelForVision2Seq, AutoImageProcessor, AutoProcessor\n        \n        embodied_type = self.config.embodied.embodied_type\n     
   \n        if embodied_type not in [\"openvla-oft\", \"openvla\"]:\n            raise ValueError(f\"Unsupported embodied type: {embodied_type}\")\n        \n        # Import based on type\n        module_name = embodied_type.replace(\"-\", \"_\")\n        config_module = f\"siirl.models.embodied.{module_name}.configuration_prismatic\"\n        model_module = f\"siirl.models.embodied.{module_name}.modeling_prismatic\"\n        processor_module = f\"siirl.models.embodied.{module_name}.processing_prismatic\"\n        \n        # Dynamic import\n        from importlib import import_module\n        config_mod = import_module(config_module)\n        model_mod = import_module(model_module)\n        processor_mod = import_module(processor_module)\n        \n        # Register classes\n        AutoConfig.register(\"openvla\", config_mod.OpenVLAConfig)\n        AutoImageProcessor.register(config_mod.OpenVLAConfig, processor_mod.PrismaticImageProcessor)\n        AutoProcessor.register(config_mod.OpenVLAConfig, processor_mod.PrismaticProcessor)\n        AutoModelForVision2Seq.register(config_mod.OpenVLAConfig, model_mod.OpenVLAForActionPrediction)\n        \n        # Update automap on rank 0 (with file locking for safety)\n        # Note: update_auto_map now includes retry logic and atomic writes\n        if self.rank == 0:\n            try:\n                from siirl.utils.embodied.openvla_utils import update_auto_map, check_model_logic_mismatch\n                logger.info(f\"[rank-{self.rank}] Updating auto_map for {model_path}\")\n                update_auto_map(model_path)\n                check_model_logic_mismatch(model_path)\n                logger.info(f\"[rank-{self.rank}] Successfully updated auto_map\")\n            except Exception as e:\n                logger.error(f\"[rank-{self.rank}] Failed to update auto_map: {e}\")\n                raise\n        \n        # Synchronize all ranks before proceeding\n        torch.distributed.barrier()\n\n    def _setup_openvla_oft_model(self, model: torch.nn.Module, model_path: str):\n        \"\"\"Apply OpenVLA-OFT specific configurations.\"\"\"\n        import json\n        \n        # Setup multi-camera support\n        if hasattr(model, 'vision_backbone'):\n            num_images = getattr(self.config.embodied, 'num_images_in_input', 1)\n            model.vision_backbone.set_num_images_in_input(num_images)\n            logger.info(f\"Set vision_backbone.num_images_in_input={num_images}\")\n        \n        # Load dataset statistics for action normalization\n        stats_path = os.path.join(model_path, \"dataset_statistics.json\")\n        if os.path.isfile(stats_path):\n            with open(stats_path, \"r\") as f:\n                model.norm_stats = json.load(f)\n            logger.info(f\"Loaded dataset_statistics.json with {len(model.norm_stats)} task(s)\")\n        else:\n            logger.warning(\n                \"No dataset_statistics.json found for OpenVLA-OFT. \"\n                \"This is expected for base checkpoints but may cause errors for fine-tuned models.\"\n            )\n\n    def _apply_model_modifications(\n        self,\n        model: torch.nn.Module,\n        use_liger: bool,\n        use_remove_padding: bool,\n        use_fused_kernels: bool,\n        enable_gradient_checkpointing: bool,\n        torch_dtype: torch.dtype,\n    ):\n        \"\"\"\n        Apply various model modifications and optimizations.\n        \n        Modifications include:\n            1. Liger kernel optimization\n            2. 
Monkey patch (for remove padding and Ulysses SP)\n            3. Ensure correct dtype for all parameters\n            4. Gradient checkpointing\n            5. LoRA adapter\n        \n        Args:\n            model: Model to modify\n            use_liger: Whether to apply Liger kernel\n            use_remove_padding: Whether to use remove padding\n            use_fused_kernels: Whether to use fused kernels\n            enable_gradient_checkpointing: Whether to enable gradient checkpointing\n            torch_dtype: Target torch dtype\n        \n        Returns:\n            Modified model\n        \"\"\"\n        is_embodied = self.config.model.model_type == \"embodied\"\n        \n        # Apply Liger kernel optimization\n        if use_liger:\n            from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance\n            _apply_liger_kernel_to_instance(model=model)\n            logger.info(\"Applied Liger kernel to model\")\n        \n        # Apply monkey patch for non-embodied models\n        if not is_embodied:\n            from siirl.models.transformers.monkey_patch import apply_monkey_patch\n            apply_monkey_patch(\n                model=model,\n                use_remove_padding=use_remove_padding,\n                ulysses_sp_size=self.ulysses_sequence_parallel_size,\n                use_fused_kernels=use_fused_kernels,\n            )\n        \n        # Ensure dtype consistency\n        model.to(torch_dtype)\n        \n        # Enable gradient checkpointing\n        if enable_gradient_checkpointing:\n            model.gradient_checkpointing_enable(\n                gradient_checkpointing_kwargs={\"use_reentrant\": False}\n            )\n            logger.info(\"Enabled gradient checkpointing\")\n        \n        # Apply LoRA adapter\n        if self._is_lora:\n            logger.info(\"Applying LoRA to model\")\n            model.enable_input_require_grads()\n            lora_config = {\n                \"task_type\": TaskType.CAUSAL_LM,\n                \"r\": self.config.model.lora_rank,\n                \"lora_alpha\": self.config.model.lora_alpha,\n                \"target_modules\": convert_to_regular_types(self.config.model.target_modules),\n                \"bias\": \"none\"\n            }\n            model = get_peft_model(model, LoraConfig(**lora_config))\n        \n        return model\n\n    def _setup_fsdp_wrapper(\n        self,\n        model: torch.nn.Module,\n        fsdp_config: FSDPArguments,\n        role: Role,\n        enable_activation_offload: bool,\n        enable_gradient_checkpointing: bool,\n    ):\n        \"\"\"\n        Wrap model with FSDP (Fully Sharded Data Parallel).\n        \n        Steps:\n            1. Configure mixed precision\n            2. Get wrap policy\n            3. Wrap model based on strategy (fsdp/fsdp2)\n            4. 
Apply activation offload if needed\n        \n        Args:\n            model: Model to wrap\n            fsdp_config: FSDP configuration\n            role: Role (Actor or RefPolicy)\n            enable_activation_offload: Whether to enable activation offload\n            enable_gradient_checkpointing: Whether gradient checkpointing is enabled\n        \n        Returns:\n            FSDP wrapped model\n        \"\"\"\n        from torch.distributed.fsdp import CPUOffload, MixedPrecision\n        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP\n        from siirl.utils.model_utils.torch_dtypes import PrecisionType\n        \n        is_embodied = self.config.model.model_type == \"embodied\"\n        \n        # Configure mixed precision\n        mixed_precision_config = fsdp_config.mixed_precision\n        if mixed_precision_config is not None:\n            param_dtype = PrecisionType.to_dtype(mixed_precision_config.param_dtype)\n            reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.reduce_dtype)\n            buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.buffer_dtype)\n        else:\n            param_dtype = torch.bfloat16\n            reduce_dtype = torch.float32\n            buffer_dtype = torch.float32\n        \n        mixed_precision = MixedPrecision(\n            param_dtype=param_dtype,\n            reduce_dtype=reduce_dtype,\n            buffer_dtype=buffer_dtype\n        )\n        \n        # Get wrap policy\n        auto_wrap_policy = get_fsdp_wrap_policy(\n            module=model,\n            config=fsdp_config.wrap_policy,\n            is_lora=self.config.model.lora_rank > 0\n        )\n        \n        # Special case: HFRollout with Gemma\n        if self._is_rollout and self.config.rollout.name == \"hf\" and not is_embodied:\n            auto_wrap_policy = None\n        \n        if self.rank == 0:\n            logger.info(f\"wrap_policy: {auto_wrap_policy}\")\n        \n        # Prepare FSDP mesh and strategy\n        fsdp_mesh = self.device_mesh\n        sharding_strategy = get_sharding_strategy(fsdp_mesh)\n        fsdp_strategy = self.config.actor.strategy\n        \n        # Wrap model based on FSDP strategy\n        if fsdp_strategy == \"fsdp\":\n            # FSDP v1\n            cpu_offload = None if role == Role.Actor else CPUOffload(offload_params=True)\n            model_fsdp = FSDP(\n                model,\n                cpu_offload=cpu_offload,\n                param_init_fn=init_fn,\n                use_orig_params=False,\n                auto_wrap_policy=auto_wrap_policy,\n                device_id=get_device_id(),\n                sharding_strategy=sharding_strategy,\n                mixed_precision=mixed_precision,\n                sync_module_states=True,\n                device_mesh=self.device_mesh,\n                forward_prefetch=False,\n            )\n        elif fsdp_strategy == \"fsdp2\":\n            # FSDP v2\n            assert CPUOffloadPolicy is not None, \\\n                \"PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)\"\n            \n            mp_policy = MixedPrecisionPolicy(\n                param_dtype=param_dtype,\n                reduce_dtype=reduce_dtype,\n                cast_forward_inputs=True\n            )\n            \n            if role == Role.Actor and fsdp_config.offload_policy:\n                cpu_offload = CPUOffloadPolicy(pin_memory=True)\n                self._is_offload_param = False\n                self._is_offload_optimizer = 
False\n            else:\n                cpu_offload = None if role == Role.Actor else CPUOffloadPolicy(pin_memory=True)\n            \n            fsdp_kwargs = {\n                \"mesh\": fsdp_mesh,\n                \"mp_policy\": mp_policy,\n                \"offload_policy\": cpu_offload,\n                \"reshard_after_forward\": fsdp_config.reshard_after_forward,\n            }\n            full_state = model.state_dict()\n            apply_fsdp2(model, fsdp_kwargs, fsdp_config)\n            fsdp2_load_full_state_dict(model, full_state, fsdp_mesh, cpu_offload)\n            model_fsdp = model\n        else:\n            raise NotImplementedError(f\"FSDP strategy '{fsdp_strategy}' not implemented\")\n        \n        # Apply activation offload\n        if enable_activation_offload:\n            enable_activation_offloading(\n                model_fsdp, fsdp_strategy, enable_gradient_checkpointing\n            )\n        \n        return model_fsdp\n\n    def _create_optimizer_and_scheduler(\n        self,\n        model: torch.nn.Module,\n        optim_config: Optional[OptimizerArguments],\n        role: Role,\n    ):\n        \"\"\"\n        Create optimizer and learning rate scheduler.\n        \n        Only creates when role is Actor and optim_config is provided.\n        \n        Args:\n            model: Model to create optimizer for\n            optim_config: Optimizer configuration\n            role: Role (Actor or RefPolicy)\n        \n        Returns:\n            Tuple of (optimizer, lr_scheduler), returns (None, None) if not needed\n        \"\"\"\n        if role != Role.Actor or optim_config is None:\n            return None, None\n        \n        from torch import optim\n        from siirl.utils.model_utils.torch_functional import (\n            get_constant_schedule_with_warmup,\n            get_cosine_schedule_with_warmup\n        )\n        \n        # Create optimizer\n        optimizer = optim.AdamW(\n            model.parameters(),\n            lr=optim_config.lr,\n            betas=optim_config.betas,\n            weight_decay=optim_config.weight_decay,\n        )\n        \n        # Calculate warmup steps\n        total_steps = optim_config.total_training_steps\n        num_warmup_steps = int(optim_config.lr_warmup_steps)\n        if num_warmup_steps < 0:\n            num_warmup_steps_ratio = optim_config.lr_warmup_steps_ratio\n            num_warmup_steps = int(num_warmup_steps_ratio * total_steps)\n        \n        if self.rank == 0:\n            logger.info(f\"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}\")\n        \n        # Create learning rate scheduler\n        warmup_style = optim_config.warmup_style\n        if warmup_style == \"constant\":\n            lr_scheduler = get_constant_schedule_with_warmup(\n                optimizer=optimizer,\n                num_warmup_steps=num_warmup_steps\n            )\n        elif warmup_style == \"cosine\":\n            min_lr_ratio = optim_config.min_lr_ratio if optim_config.min_lr_ratio else 0.0\n            num_cycles = optim_config.num_cycles\n            lr_scheduler = get_cosine_schedule_with_warmup(\n                optimizer=optimizer,\n                num_warmup_steps=num_warmup_steps,\n                num_training_steps=total_steps,\n                min_lr_ratio=min_lr_ratio,\n                num_cycles=num_cycles\n            )\n        else:\n            raise NotImplementedError(f\"Warmup style '{warmup_style}' is not supported\")\n        \n        return optimizer, 
lr_scheduler\n\n    def _build_rollout(self, trust_remote_code=False):\n        from siirl.utils.model_utils.model import get_generation_config\n\n        local_path = copy_to_local(self.config.model.path)\n        self.generation_config = get_generation_config(local_path, trust_remote_code=trust_remote_code)\n\n        # TODO(sgm): support FSDP hybrid shard for larger model\n        rollout_name = self.config.rollout.name\n        if rollout_name == \"hf\":\n            if self.config.model.model_type == \"embodied\":\n                from siirl.engine.rollout.embodied_rollout import EmbodiedHFRollout\n                rollout = EmbodiedHFRollout(module=None, config=self.config)\n            else: \n                from siirl.engine.rollout import HFRollout\n                rollout = HFRollout(module=None, config=self.config)\n\n        elif rollout_name == \"vllm\":\n            from siirl.engine.rollout.vllm_rollout import vllm_mode, vLLMRollout\n\n            local_path = copy_to_local(self.config.model.path, use_shm=self.config.model.use_shm)\n            lora_kwargs = {\"lora_kwargs\": {\"enable_lora\": True, \"max_loras\": 1, \"max_lora_rank\": self._lora_rank}} if self._is_lora else {}\n            # lora_kwargs = {}\n            if vllm_mode == \"customized\":\n                rollout = vLLMRollout(actor_module=self.actor_module_fsdp, config=self.config.rollout, tokenizer=self.tokenizer, model_hf_config=self.actor_model_config, trust_remote_code=trust_remote_code, **lora_kwargs)\n            elif vllm_mode == \"spmd\":\n                from siirl.engine.rollout.vllm_rollout import vLLMAsyncRollout\n\n                vllm_rollout_cls = vLLMRollout if self.config.rollout.mode == \"sync\" else vLLMAsyncRollout\n                rollout = vllm_rollout_cls(model_path=local_path, config=self.config.rollout, tokenizer=self.tokenizer, model_hf_config=self.actor_model_config, trust_remote_code=trust_remote_code, **lora_kwargs)\n            else:\n                raise NotImplementedError(\"vllm_mode must be 'customized' or 'spmd'\")\n\n            if self.device_mesh.mesh.numel() == 1:\n                self.config.rollout.load_format = \"dummy_hf\"\n\n        elif rollout_name == \"sglang\":\n            from siirl.engine.rollout.sglang_rollout import SGLangRollout\n\n            # NOTE(linjunrong): Due to recent fp8 support in SGLang. Now importing any symbol relate to\n            # SGLang's model_runner would check CUDA device capability. 
However, due to siirl's setting,\n            # the main process of ray can not find any CUDA device, which would potentially lead to:\n            # \"RuntimeError: No CUDA GPUs are available\".\n            # For this reason, sharding_manager.__init__ should not import FSDPSGLangShardingManager and\n            # we import it here use the abs path.\n            # check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76\n            # from siirl.engine.sharding_manager.fsdp_sglang import MultiAgentFSDPSGLangShardingManager\n\n            local_path = copy_to_local(self.config.model.path)\n            rollout = SGLangRollout(\n                actor_module=local_path,\n                config=self.config.rollout,\n                tokenizer=self.tokenizer,\n                model_hf_config=self.actor_model_config,\n                processing_class=self.processor if self.processor is not None else self.tokenizer,\n                trust_remote_code=trust_remote_code,\n            )\n\n        else:\n            raise NotImplementedError(f\"Rollout name: {self.config.rollout.name} is not supported\")\n\n        return rollout, None\n\n    def init_model(self):\n        from siirl.engine.actor import DataParallelPPOActor, RobDataParallelPPOActor\n\n        # This is used to import external_lib into the huggingface systems\n        import_external_libs(self.config.model.external_lib)\n\n        override_model_config = self.config.model.override_config\n        use_remove_padding = self.config.model.use_remove_padding\n        use_fused_kernels = self.config.model.use_fused_kernels\n        use_shm = self.config.model.use_shm\n\n        tokenizer_module = load_tokenizer(model_args=self.config.model)\n        self.tokenizer, self.processor = tokenizer_module[\"tokenizer\"], tokenizer_module[\"processor\"]\n\n        if self._is_actor:\n            optim_config = self.config.actor.optim\n            fsdp_config = self.config.actor.fsdp_config\n\n            local_path = copy_to_local(self.config.model.path, use_shm=use_shm)\n            (\n                self.actor_module_fsdp,\n                self.actor_optimizer,\n                self.actor_lr_scheduler,\n                self.actor_model_config,\n            ) = self._build_model_optimizer(\n                model_path=local_path,\n                fsdp_config=fsdp_config,\n                optim_config=optim_config,\n                override_model_config=override_model_config,\n                use_remove_padding=use_remove_padding,\n                use_fused_kernels=use_fused_kernels,\n                enable_gradient_checkpointing=self.config.model.enable_gradient_checkpointing,\n                trust_remote_code=self.config.model.trust_remote_code,\n                use_liger=self.config.model.use_liger,\n                role=Role.Actor,\n                enable_activation_offload=self.config.model.enable_activation_offload,\n            )\n\n            # get the original unwrapped module\n            if fsdp_version(self.actor_module_fsdp) == 1:\n                self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module\n\n            if self._is_offload_param:\n                offload_fsdp_model_to_cpu(self.actor_module_fsdp)\n\n            if self._is_offload_optimizer:\n                offload_fsdp_optimizer(optimizer=self.actor_optimizer)\n        # load from checkpoint\n        if self._is_actor:\n            self.config.actor.use_remove_padding = 
use_remove_padding\n            self.config.actor.use_fused_kernels = use_fused_kernels\n            \n            # Select appropriate Actor class based on model type and pass embodied parameters\n            is_embodied_model = self.config.model.model_type == \"embodied\"\n            if is_embodied_model:\n                self.config.actor.embodied_type = self.config.embodied.embodied_type\n                self.config.actor.action_token_len = self.config.embodied.action_token_len\n                self.config.actor.action_chunks_len = self.config.embodied.action_chunks_len\n                \n                from siirl.engine.actor.embodied_actor import RobDataParallelPPOActor\n                ActorClass = RobDataParallelPPOActor\n            else:\n                ActorClass = DataParallelPPOActor\n            \n            self.actor = ActorClass(\n                config=self.config.actor,\n                actor_module=self.actor_module_fsdp,\n                actor_optimizer=self.actor_optimizer\n            )\n\n        if self._is_rollout:\n            from transformers import AutoConfig\n\n            local_path = copy_to_local(self.config.model.path)\n            self.actor_model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=self.config.model.trust_remote_code)\n            self.flops_counter = FlopsCounter(self.actor_model_config, forward_only=True)\n            self.rollout, self.rollout_sharding_manager = self._build_rollout(trust_remote_code=self.config.model.trust_remote_code)\n\n        if self._is_ref:\n            local_path = copy_to_local(self.config.model.path, use_shm=use_shm)\n            self.ref_module_fsdp = self._build_model_optimizer(\n                model_path=local_path,\n                fsdp_config=self.config.ref.fsdp_config,\n                optim_config=None,\n                override_model_config=override_model_config,\n                use_remove_padding=use_remove_padding,\n                use_fused_kernels=use_fused_kernels,\n                trust_remote_code=self.config.model.trust_remote_code,\n                use_liger=self.config.model.use_liger,\n                role=Role.RefPolicy,\n            )[0]\n            self.config.ref.use_remove_padding = use_remove_padding\n            self.config.ref.use_fused_kernels = use_fused_kernels\n            \n            # Pass embodied parameters to RefPolicy for embodied models and select class\n            is_embodied_model = self.config.model.model_type == \"embodied\"\n            if is_embodied_model:\n                self.config.ref.embodied_type = self.config.embodied.embodied_type\n                self.config.ref.action_token_len = self.config.embodied.action_token_len\n                self.config.ref.action_chunks_len = self.config.embodied.action_chunks_len\n                \n                from siirl.engine.actor.embodied_actor import RobDataParallelPPOActor\n                RefPolicyClass = RobDataParallelPPOActor\n            else:\n                RefPolicyClass = DataParallelPPOActor\n            \n            self.ref_policy = RefPolicyClass(\n                config=self.config.ref,\n                actor_module=self.ref_module_fsdp\n            )\n\n        if self._is_actor:\n            self.flops_counter = FlopsCounter(self.actor_model_config)\n            self.checkpoint_manager = FSDPCheckpointManager(\n                model=self.actor_module_fsdp, optimizer=self.actor.actor_optimizer, lr_scheduler=self.actor_lr_scheduler, processing_class=self.processor if 
self.processor is not None else self.tokenizer, checkpoint_contents=self.config.actor.checkpoint.contents, tokenizer=self.tokenizer\n            )\n\n    def update_actor(self, data: TensorDict):\n        # Support all hardwares\n        data = data.to(get_device_id())\n\n        assert self._is_actor\n        if self._is_offload_param:\n            load_fsdp_model_to_gpu(self.actor_module_fsdp)\n        if self._is_offload_optimizer:\n            load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=get_device_id())\n\n        with self.ulysses_sharding_manager:\n            data = self.ulysses_sharding_manager.preprocess_data(data=data)\n            # perform training\n            with Timer(name=\"update_policy\", logger=None) as timer:\n                metrics = self.actor.update_policy(data=data)\n            delta_time = timer.last\n            global_num_tokens = data[\"global_token_num\"]\n            estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)\n            metrics[\"perf/mfu/actor\"] = estimated_flops * self.config.actor.ppo_epochs / promised_flops\n            metrics[\"perf/delta_time/actor\"] = delta_time\n            metrics[\"perf/max_memory_allocated_gb\"] = get_torch_device().max_memory_allocated() / (1024**3)\n            metrics[\"perf/max_memory_reserved_gb\"] = get_torch_device().max_memory_reserved() / (1024**3)\n            metrics[\"perf/cpu_memory_used_gb\"] = psutil.virtual_memory().used / (1024**3)\n\n            lr = self.actor_lr_scheduler.get_last_lr()[0]\n            metrics[\"actor/lr\"] = lr\n            self.actor_lr_scheduler.step()\n\n            # TODO: here, we should return all metrics\n            data[\"metrics\"] = NonTensorData(metrics)\n            processed_data = self.ulysses_sharding_manager.postprocess_data(data=data)\n            processed_data = processed_data.to(\"cpu\")\n\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.actor_module_fsdp)\n        if self._is_offload_optimizer:\n            offload_fsdp_optimizer(optimizer=self.actor_optimizer)\n\n        return processed_data\n\n    def generate_sequences(self, prompts: TensorDict):\n        # Support all hardwares\n        prompts = prompts.to(get_device_id())\n        assert self._is_rollout\n        prompts[\"eos_token_id\"] = NonTensorData(self.generation_config.eos_token_id if self.generation_config is not None else self.tokenizer.eos_token_id)\n        prompts[\"pad_token_id\"] = NonTensorData(self.generation_config.pad_token_id if self.generation_config is not None else self.tokenizer.pad_token_id)\n        \n        with self.rollout_sharding_manager:\n            if self.config.rollout.name == \"sglang_async\":\n                from siirl.engine.rollout.sglang_rollout import AsyncSGLangRollout\n\n                if isinstance(self.rollout, AsyncSGLangRollout) and hasattr(self.rollout, \"_tool_schemas\") and len(self.rollout._tool_schemas) > 0:\n                    output = self.rollout.generate_sequences_with_tools(prompts=prompts)\n                else:\n                    output = self.rollout.generate_sequences(prompts=prompts)\n            else:\n                with Timer(name=\"generate_sequences\", logger=None) as timer:\n                    output = self.rollout.generate_sequences(prompts=prompts)\n                total_input_tokens = output[\"total_input_tokens\"] if \"total_input_tokens\" in output else 0\n                total_output_tokens = output[\"total_output_tokens\"] 
if \"total_output_tokens\" in output else 0\n                delta_time = timer.last\n\n                # Calculate correct batch_seqlens for MFU computation\n                # Get batch size from prompts\n                batch_size = prompts[\"input_ids\"].shape[0] if \"input_ids\" in prompts else 1\n                # Calculate average sequence length per sample (prompt + response)\n                avg_seq_len = (total_input_tokens + total_output_tokens) / batch_size if batch_size > 0 else 0\n                # Create batch_seqlens list with each sample's average length\n                batch_seqlens = [int(avg_seq_len)] * batch_size\n\n                estimated_flops, promised_flops = self.flops_counter.estimate_flops(batch_seqlens, delta_time)\n                metrics = {}\n                # MFU should not be divided by TP size - it's already per-GPU\n                metrics[\"perf/mfu/rollout\"] = estimated_flops / promised_flops\n                metrics[\"perf/delta_time/rollout\"] = delta_time\n                output[\"metrics\"] = NonTensorData(metrics, batch_size=None)\n\n\n        output = output.to(\"cpu\")\n\n        # clear kv cache\n        get_torch_device().empty_cache()\n        return output\n\n    def compute_log_prob(self, data: TensorDict):\n        # when is_lora is True, we use the actor without lora applied to calculate the log_prob\n        # which is mostly used for ref log_prob calculation\n        assert self._is_actor\n        \n        if self._is_offload_param:\n            load_fsdp_model_to_gpu(self.actor_module_fsdp)\n\n        # Support all hardwares\n        from contextlib import nullcontext\n        is_lora = data.pop(\"is_lora\", False)\n        adapter_ctx = self.actor.actor_module.disable_adapter() if is_lora else nullcontext()\n        \n        data = data.to(get_device_id())\n        \n        # we should always recompute old_log_probs when it is HybridEngine\n        data[\"micro_batch_size\"] = NonTensorData(self.config.rollout.log_prob_micro_batch_size_per_gpu)  \n        data[\"max_token_len\"] = NonTensorData(self.config.rollout.log_prob_max_token_len_per_gpu)\n        data[\"use_dynamic_bsz\"] = NonTensorData(self.config.rollout.log_prob_use_dynamic_bsz)\n        data[\"temperature\"] = NonTensorData(self.config.rollout.temperature)\n        data[\"pad_token_id\"] = NonTensorData(self.tokenizer.pad_token_id)\n\n        # perform recompute log_prob\n        with self.ulysses_sharding_manager:\n            data = self.ulysses_sharding_manager.preprocess_data(data)\n            with Timer(name=\"compute_actor_log_prob\", logger=None) as timer, adapter_ctx:\n                output, entropys = self.actor.compute_log_prob(data=data, calculate_entropy=True)\n            delta_time = timer.last\n            global_num_tokens = data[\"global_token_num\"]\n            estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)\n            metrics = {\n                # actor forward\n                \"perf/mfu/actor_log_prob\": estimated_flops / promised_flops / 3,\n                \"perf/delta_time/actor_log_prob\": delta_time,\n            }\n            data[\"old_log_probs\"] = output\n            if entropys is not None:\n                data[\"entropys\"] = entropys \n            data[\"metrics\"] = NonTensorData(metrics)\n            processed_data = self.ulysses_sharding_manager.postprocess_data(data)\n\n        processed_data = processed_data.to(\"cpu\")\n\n        # 
https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes\n        # unshard the root FSDP module\n        if self.group_world_size > 1 and fsdp_version(self.actor.actor_module) == 1:\n            self.actor.actor_module._handle.reshard(True)\n\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.actor_module_fsdp)\n\n        return processed_data\n\n    def compute_ref_log_prob(self, data: TensorDict):    \n        if self._is_lora:\n            # if _is_lora, actor without lora applied is the ref\n            data[\"is_lora\"] = NonTensorData(True)\n            data = self.compute_log_prob(data)\n            # this old_log_probs is in fact ref_log_prob\n            data = TensorDict({\"ref_log_prob\": data[\"old_log_probs\"]})\n            return data\n        assert self._is_ref\n        # else:\n        # otherwise, the class have a standalone ref model\n        # Support all hardwares\n        data = data.to(get_device_id())\n\n        metrics = {}\n        micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu\n        data[\"micro_batch_size\"] = NonTensorData(micro_batch_size)\n        data[\"temperature\"] = NonTensorData(self.config.rollout.temperature)\n        data[\"max_token_len\"] = NonTensorData(self.config.ref.log_prob_max_token_len_per_gpu)\n        data[\"use_dynamic_bsz\"] = NonTensorData(self.config.ref.log_prob_use_dynamic_bsz)\n        with self.ulysses_sharding_manager:\n            data = self.ulysses_sharding_manager.preprocess_data(data)\n            with Timer(name=\"compute_log_prob\", logger=None) as timer:\n                output, _ = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False)\n            delta_time = timer.last\n            global_num_tokens = data[\"global_token_num\"]\n            estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)\n            metrics[\"perf/mfu/ref\"] = estimated_flops / promised_flops\n            metrics[\"perf/delta_time/ref\"] = delta_time\n            data[\"ref_log_prob\"] = output\n            data[\"metrics\"] = NonTensorData(metrics)\n            processed_data = self.ulysses_sharding_manager.postprocess_data(data)\n\n        processed_data = processed_data.to(\"cpu\")\n\n        # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes\n        # unshard the root FSDP module\n        if self.group_world_size > 1 and fsdp_version(self.ref_policy.actor_module) == 1:\n            self.ref_policy.actor_module._handle.reshard(True)\n\n        return processed_data\n\n    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):\n        # only support save and load ckpt for actor\n        assert self._is_actor\n        import torch\n\n        if self._is_offload_param:\n            load_fsdp_model_to_gpu(self.actor_module_fsdp)\n\n        self.checkpoint_manager.save_checkpoint(local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep)\n\n        torch.distributed.barrier()\n\n        if self._is_lora and hasattr(getattr(self, \"actor_module\", self.actor_module_fsdp), \"peft_config\"):\n            lora_save_path = os.path.join(local_path, \"lora_adapter\")\n            peft_model = getattr(self, \"actor_module\", self.actor_module_fsdp)\n            peft_config = {}\n            if torch.distributed.get_rank() == 0:\n                os.makedirs(lora_save_path, exist_ok=True)\n                peft_config = 
asdict(peft_model.peft_config.get(\"default\", {}))\n                peft_config[\"task_type\"] = peft_config[\"task_type\"].value\n                peft_config[\"peft_type\"] = peft_config[\"peft_type\"].value\n                peft_config[\"target_modules\"] = list(peft_config[\"target_modules\"])\n            try:\n                if fsdp_version(self.actor_module_fsdp) > 0:\n                    self.actor_module_fsdp = self.actor_module_fsdp.to(get_device_name())\n                    lora_params = layered_summon_lora_params(self.actor_module_fsdp)\n                    if torch.distributed.get_rank() == 0:\n                        save_file(lora_params, os.path.join(lora_save_path, \"adapter_model.safetensors\"))\n                        with open(os.path.join(lora_save_path, \"adapter_config.json\"), \"w\", encoding=\"utf-8\") as f:\n                            json.dump(peft_config, f, ensure_ascii=False, indent=4)\n            except Exception as e:\n                if torch.distributed.get_rank() == 0:\n                    logger.info(f\"[rank-{self.rank}]: Save LoRA Adapter Error ({e})\")\n\n            torch.distributed.barrier()\n            if torch.distributed.get_rank() == 0:\n                logger.info(f\"[rank-{self.rank}]: Saved LoRA adapter to: {lora_save_path}\")\n\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.actor_module_fsdp)\n\n    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=False):\n        if self._is_offload_param:\n            load_fsdp_model_to_gpu(self.actor_module_fsdp)\n\n        self.checkpoint_manager.load_checkpoint(local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load)\n\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.actor_module_fsdp)\n\n        if self._is_offload_optimizer:\n            offload_fsdp_optimizer(self.actor_optimizer)\n\n    def set_rollout_sharding_manager(self, sharding_manager):\n        self.rollout_sharding_manager = sharding_manager\n        self.rollout.sharding_manager = sharding_manager\n\n\nclass CriticWorker(Worker):\n    def __init__(self, config: CriticArguments, process_group: ProcessGroup):\n        super().__init__()\n        import torch.distributed\n\n        if not torch.distributed.is_initialized():\n            torch.distributed.init_process_group(backend=get_nccl_backend())\n        self.config = config\n        world_size = torch.distributed.get_world_size(group=process_group)\n        self.group_world_size = world_size\n        # build device mesh for Ulysses Sequence Parallel\n        self.device_mesh = create_device_mesh_from_group(process_group=process_group, fsdp_size=self.config.model.fsdp_config.fsdp_size)\n\n        self.ulysses_device_mesh = None\n        self.ulysses_sequence_parallel_size = self.config.ulysses_sequence_parallel_size\n        if self.ulysses_sequence_parallel_size > 1:\n            self.ulysses_device_mesh = create_device_mesh_from_group(process_group=process_group, sp_size=self.ulysses_sequence_parallel_size)\n\n        self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)\n\n        # set FSDP offload params\n        self._is_offload_param = self.config.model.fsdp_config.param_offload\n        self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload\n\n        # normalize config\n        self.config.ppo_mini_batch_size *= self.config.rollout_n\n        self.config.ppo_mini_batch_size //= world_size // 
self.ulysses_sequence_parallel_size\n        if self.config.ppo_micro_batch_size is not None:\n            self.config.ppo_micro_batch_size //= world_size // self.ulysses_sequence_parallel_size\n            self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size\n\n        if self.config.ppo_micro_batch_size_per_gpu is not None:\n            assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0, f\"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be divisible by ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}\"\n            assert self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu > 0, f\"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}\"\n        self._is_lora = self.config.model.lora_rank > 0\n\n    def _build_critic_model_optimizer(self, config):\n        # the following line is necessary\n        from torch import optim\n        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP\n        from torch.distributed.fsdp import MixedPrecision\n\n        from siirl.utils.model_utils.model import print_model_size\n        from siirl.utils.model_utils.torch_dtypes import PrecisionType\n\n        use_shm = config.model.use_shm\n        local_path = copy_to_local(config.model.path, use_shm=use_shm)\n        # note that the tokenizer between actor and critic may be different. So override tokenizer info with actor info\n        # using random initialized model from any architecture. May not be the same as Actor.\n\n        tokenizer_module = load_tokenizer(model_args=self.config.model)\n        self.tokenizer, self.processor = tokenizer_module[\"tokenizer\"], tokenizer_module[\"processor\"]\n\n        override_config = self.config.model.override_config\n        override_config_kwargs = {\n            \"bos_token_id\": self.tokenizer.bos_token_id,\n            \"eos_token_id\": self.tokenizer.eos_token_id,\n            \"pad_token_id\": self.tokenizer.pad_token_id,\n        }\n        override_config_kwargs.update(override_config)\n        # if self.rank == 0:\n        #     logger.info(f\"Critic overriding config {override_config_kwargs}\")\n\n        torch_dtype = self.config.model.fsdp_config.model_dtype\n        torch_dtype = PrecisionType.to_dtype(torch_dtype)\n\n        from transformers import AutoConfig, AutoModelForTokenClassification\n\n        critic_model_config = AutoConfig.from_pretrained(local_path, attn_implementation=\"flash_attention_2\", trust_remote_code=config.model.trust_remote_code)\n        critic_model_config.num_labels = 1\n        # patch for kimi-vl\n        if getattr(critic_model_config, \"model_type\", None) == \"kimi_vl\":\n            critic_model_config.text_config.topk_method = \"greedy\"\n\n        init_context = get_init_weight_context_manager(use_meta_tensor=not critic_model_config.tie_word_embeddings, mesh=self.device_mesh)\n\n        with init_context(), warnings.catch_warnings():\n            warnings.simplefilter(\"ignore\")\n            critic_model_config.classifier_dropout = 0.0\n            critic_model_config.hidden_dropout = \"0\"\n            critic_module = AutoModelForTokenClassification.from_pretrained(\n                pretrained_model_name_or_path=local_path,\n                torch_dtype=torch_dtype,\n                config=critic_model_config,\n                
trust_remote_code=config.model.trust_remote_code,\n            )\n\n            use_remove_padding = config.model.use_remove_padding\n\n            # Apply monkey patch for performance optimizations\n            from siirl.models.transformers.monkey_patch import apply_monkey_patch\n            apply_monkey_patch(\n                model=critic_module,\n                use_remove_padding=use_remove_padding,\n                ulysses_sp_size=self.ulysses_sequence_parallel_size,\n            )\n\n            # some parameters may not in torch_dtype\n            critic_module.to(torch_dtype)\n\n            if config.model.enable_gradient_checkpointing:\n                critic_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={\"use_reentrant\": False})\n\n        if self._is_lora:\n            logger.info(\"Applying LoRA to critic module\")\n            critic_module.enable_input_require_grads()\n            # Convert config to regular Python types before creating PEFT model\n            lora_config = {\n                \"task_type\": TaskType.CAUSAL_LM,\n                \"r\": self.config.model.lora_rank,\n                \"lora_alpha\": self.config.model.lora_alpha,\n                \"target_modules\": convert_to_regular_types(self.config.model.target_modules),\n                \"bias\": \"none\",\n            }\n            critic_module = get_peft_model(critic_module, LoraConfig(**lora_config))\n\n        if self.rank == 0:\n            print_model_size(critic_module)\n\n        self.critic_model_config = critic_model_config\n\n        fsdp_config = self.config.model.fsdp_config\n        mixed_precision_config = fsdp_config.mixed_precision\n        if mixed_precision_config is not None:\n            param_dtype = PrecisionType.to_dtype(mixed_precision_config.param_dtype)\n            reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.reduce_dtype)\n            buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.buffer_dtype)\n        else:\n            param_dtype = torch.bfloat16\n            reduce_dtype = torch.float32\n            buffer_dtype = torch.float32\n\n        mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)\n\n        auto_wrap_policy = get_fsdp_wrap_policy(module=critic_module, config=self.config.model.fsdp_config.wrap_policy, is_lora=self.config.model.lora_rank > 0)\n\n        fsdp_mesh = self.device_mesh\n        sharding_strategy = get_sharding_strategy(fsdp_mesh)\n\n        # Note: We force turn off CPUOffload for critic because it causes incorrect results when using grad accumulation\n        if config.strategy == \"fsdp\":\n            critic_module = FSDP(\n                critic_module,\n                param_init_fn=init_fn,\n                use_orig_params=False,\n                auto_wrap_policy=auto_wrap_policy,\n                device_id=get_device_id(),\n                sharding_strategy=sharding_strategy,\n                mixed_precision=mixed_precision,\n                sync_module_states=True,\n                forward_prefetch=False,\n                device_mesh=self.device_mesh,\n                cpu_offload=None,\n            )\n        elif config.strategy == \"fsdp2\":\n            assert CPUOffloadPolicy is not None, \"PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)\"\n            mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True)\n            offload_policy = None\n            
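# Note: when the FSDP2 offload policy is enabled, CPUOffloadPolicy keeps parameters and optimizer\n            # state on CPU itself, so the manual offload flags are cleared below to avoid double offloading.\n            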
if fsdp_config.offload_policy:\n                self._is_offload_param = False\n                self._is_offload_optimizer = False\n                offload_policy = CPUOffloadPolicy(pin_memory=True)\n\n            fsdp_kwargs = {\n                \"mesh\": fsdp_mesh,\n                \"mp_policy\": mp_policy,\n                \"offload_policy\": offload_policy,\n                \"reshard_after_forward\": fsdp_config.reshard_after_forward,\n            }\n            full_state = critic_module.state_dict()\n            apply_fsdp2(critic_module, fsdp_kwargs, fsdp_config)\n            fsdp2_load_full_state_dict(critic_module, full_state, fsdp_mesh, offload_policy)\n        else:\n            raise NotImplementedError(f\"Unknown strategy {config.strategy}\")\n\n        if config.model.enable_activation_offload:\n            enable_gradient_checkpointing = config.model.enable_gradient_checkpointing\n            enable_activation_offloading(critic_module, config.strategy, enable_gradient_checkpointing)\n\n        critic_optimizer = optim.AdamW(\n            critic_module.parameters(),\n            lr=config.optim.lr,\n            betas=config.optim.betas,\n            weight_decay=config.optim.weight_decay,\n        )\n\n        total_steps = config.optim.total_training_steps\n        num_warmup_steps = int(config.optim.lr_warmup_steps)\n        warmup_style = config.optim.warmup_style\n        if num_warmup_steps < 0:\n            num_warmup_steps_ratio = config.optim.lr_warmup_steps_ratio\n            num_warmup_steps = int(num_warmup_steps_ratio * total_steps)\n\n        if self.rank == 0:\n            logger.info(f\"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}\")\n\n        from siirl.utils.model_utils.torch_functional import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup\n\n        if warmup_style == \"constant\":\n            critic_lr_scheduler = get_constant_schedule_with_warmup(optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps)\n        elif warmup_style == \"cosine\":\n            critic_lr_scheduler = get_cosine_schedule_with_warmup(optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_steps)\n        else:\n            raise NotImplementedError(f\"Warmup style {warmup_style} is not supported\")\n\n        return critic_module, critic_optimizer, critic_lr_scheduler\n\n    def init_model(self):\n        # This is used to import external_lib into the huggingface systems\n        import_external_libs(self.config.model.external_lib)\n\n        from siirl.engine.critic import DataParallelPPOCritic\n\n        self.critic_module, self.critic_optimizer, self.critic_lr_scheduler = self._build_critic_model_optimizer(self.config)\n\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.critic_module)\n        if self._is_offload_optimizer:\n            offload_fsdp_optimizer(optimizer=self.critic_optimizer)\n\n        self.critic = DataParallelPPOCritic(config=self.config, critic_module=self.critic_module, critic_optimizer=self.critic_optimizer)\n\n        self.flops_counter = FlopsCounter(self.critic_model_config)\n        self.checkpoint_manager = FSDPCheckpointManager(\n            model=self.critic_module, optimizer=self.critic_optimizer, lr_scheduler=self.critic_lr_scheduler, processing_class=self.processor if self.processor is not None else self.tokenizer, checkpoint_contents=self.config.checkpoint.contents, tokenizer=self.tokenizer\n        )\n\n    def compute_values(self, 
data: TensorDict):\n        # Support all hardwares\n        data = data.to(get_device_id())\n\n        if self._is_offload_param:\n            load_fsdp_model_to_gpu(self.critic_module)\n        micro_batch_size = self.config.ppo_micro_batch_size_per_gpu\n        data[\"micro_batch_size\"] = NonTensorData(micro_batch_size)\n        data[\"max_token_len\"] = NonTensorData(self.config.forward_max_token_len_per_gpu)\n        data[\"use_dynamic_bsz\"] = NonTensorData(self.config.use_dynamic_bsz)\n        # perform forward computation\n        with self.ulysses_sharding_manager:\n            data = self.ulysses_sharding_manager.preprocess_data(data=data)\n            values = self.critic.compute_values(data=data)\n            data[\"values\"] = values\n            processed_data = self.ulysses_sharding_manager.postprocess_data(data=data)\n\n        processed_data = processed_data.to(\"cpu\")\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.critic_module)\n        return processed_data\n\n    def update_critic(self, data: TensorDict):\n        # Support all hardwares\n        data = data.to(get_device_id())\n        if self._is_offload_param:\n            load_fsdp_model_to_gpu(self.critic_module)\n        if self._is_offload_optimizer:\n            load_fsdp_optimizer(optimizer=self.critic_optimizer, device_id=get_device_id())\n\n        # perform forward computation\n        with self.ulysses_sharding_manager:\n            data = self.ulysses_sharding_manager.preprocess_data(data=data)\n\n            with Timer(name=\"update_critic\", logger=None) as timer:\n                metrics = self.critic.update_critic(data=data)\n            delta_time = timer.last\n\n            global_num_tokens = data[\"global_token_num\"]\n            estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)\n            metrics[\"perf/mfu/critic\"] = estimated_flops * self.config.ppo_epochs / promised_flops\n\n            self.critic_lr_scheduler.step()\n            lr = self.critic_lr_scheduler.get_last_lr()[0]\n            metrics[\"critic/lr\"] = lr\n\n            data[\"metrics\"] = NonTensorData(metrics)\n            processed_data = self.ulysses_sharding_manager.postprocess_data(data=data)\n\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.critic_module)\n        if self._is_offload_optimizer:\n            offload_fsdp_optimizer(optimizer=self.critic_optimizer)\n\n        processed_data = processed_data.to(\"cpu\")\n        return processed_data\n\n    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):\n        import torch\n\n        if self._is_offload_param:\n            load_fsdp_model_to_gpu(self.critic_module)\n\n        self.checkpoint_manager.save_checkpoint(local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep)\n\n        torch.distributed.barrier()\n        if self._is_offload_param:\n            offload_fsdp_model_to_cpu(self.critic_module)\n\n    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):\n        import torch\n\n        if self._is_offload_param:\n            load_fsdp_model_to_gpu(self.critic_module)\n\n        self.checkpoint_manager.load_checkpoint(local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load)\n\n        torch.distributed.barrier()\n        if self._is_offload_param:\n            
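# move the freshly loaded critic parameters back to CPU so GPU memory is only held during compute\n            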
offload_fsdp_model_to_cpu(self.critic_module)\n\n        if self._is_offload_optimizer:\n            offload_fsdp_optimizer(self.critic_optimizer)\n\n\n# TODO(sgm): we may need to extract it to dp_reward_model.py\nclass RewardModelWorker(Worker):\n    \"\"\"\n    Note that we only implement the reward model that is subclass of AutoModelForTokenClassification.\n    \"\"\"\n\n    def __init__(self, config: RewardModelArguments, process_group: ProcessGroup):\n        super().__init__()\n        import torch.distributed\n\n        if not torch.distributed.is_initialized():\n            torch.distributed.init_process_group(backend=get_nccl_backend())\n        self.config = config\n\n        # build device mesh for Ulysses Sequence Parallel\n        world_size = torch.distributed.get_world_size(group=process_group)\n\n        self.device_mesh = create_device_mesh_from_group(process_group=process_group, fsdp_size=self.config.model.fsdp_config.fsdp_size)\n\n        self.ulysses_device_mesh = None\n        self.ulysses_sequence_parallel_size = self.config.ulysses_sequence_parallel_size\n        if self.ulysses_sequence_parallel_size > 1:\n            self.ulysses_device_mesh = create_device_mesh_from_group(process_group=process_group, sp_size=self.ulysses_sequence_parallel_size)\n\n        self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)\n\n        self.use_remove_padding = self.config.model.use_remove_padding\n\n        # normalize config\n        if self.config.micro_batch_size is not None:\n            self.config.micro_batch_size //= world_size\n            self.config.micro_batch_size_per_gpu = self.config.micro_batch_size\n\n    def _build_model(self, config: RewardModelArguments):\n        # the following line is necessary\n        from torch.distributed.fsdp import CPUOffload\n        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP\n        from transformers import AutoConfig, AutoModelForTokenClassification\n\n        use_shm = config.model.use_shm\n        # download the checkpoint from hdfs\n        local_path = copy_to_local(config.model.path, use_shm=use_shm)\n\n        if self.config.model.input_tokenizer is None:\n            self._do_switch_chat_template = False\n        else:\n            self._do_switch_chat_template = True\n            input_tokenizer_local_path = copy_to_local(config.model.input_tokenizer)\n            input_tokenizer_module = load_tokenizer(path=input_tokenizer_local_path)\n            self.input_tokenizer = input_tokenizer_module[\"tokenizer\"]\n            tokenizer_module = load_tokenizer(model_args=self.config.model)\n            self.tokenizer, self.processor = tokenizer_module[\"tokenizer\"], tokenizer_module[\"processor\"]\n\n        trust_remote_code = config.model.trust_remote_code\n        model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)\n        model_config.num_labels = 1\n\n        # note that we have to create model in fp32. 
Otherwise, the optimizer is in bf16, which is incorrect\n        init_context = get_init_weight_context_manager(use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.device_mesh)\n\n        with init_context(), warnings.catch_warnings():\n            warnings.simplefilter(\"ignore\")\n            model_config.classifier_dropout = 0.0\n            reward_module = AutoModelForTokenClassification.from_pretrained(\n                pretrained_model_name_or_path=local_path,\n                config=model_config,\n                torch_dtype=torch.bfloat16,\n                attn_implementation=\"flash_attention_2\",\n                trust_remote_code=trust_remote_code,\n            )            \n            # Apply monkey patch for performance optimizations\n            from siirl.models.transformers.monkey_patch import apply_monkey_patch\n            apply_monkey_patch(\n                model=reward_module,\n                use_remove_padding=config.model.use_remove_padding,\n                ulysses_sp_size=self.ulysses_sequence_parallel_size,\n            )\n\n            reward_module.to(torch.bfloat16)\n\n        auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.model.fsdp_config.wrap_policy)\n\n        fsdp_mesh = self.device_mesh\n        sharding_strategy = get_sharding_strategy(fsdp_mesh)\n\n        if config.strategy == \"fsdp\":\n            reward_module = FSDP(\n                reward_module,\n                param_init_fn=init_fn,\n                use_orig_params=False,\n                auto_wrap_policy=auto_wrap_policy,\n                device_id=get_device_id(),\n                sharding_strategy=sharding_strategy,  # zero3\n                sync_module_states=True,\n                cpu_offload=CPUOffload(offload_params=True),\n                forward_prefetch=False,\n                device_mesh=self.device_mesh,\n            )\n        elif config.strategy == \"fsdp2\":\n            assert CPUOffloadPolicy is not None, \"PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)\"\n            cpu_offload = CPUOffloadPolicy(pin_memory=True)\n            fsdp_kwargs = {\n                \"mesh\": fsdp_mesh,\n                \"offload_policy\": cpu_offload,\n                \"reshard_after_forward\": config.model.fsdp_config.reshard_after_forward,\n            }\n            full_state = reward_module.state_dict()\n            apply_fsdp2(reward_module, fsdp_kwargs, config.model.fsdp_config)\n            fsdp2_load_full_state_dict(reward_module, full_state, fsdp_mesh, cpu_offload)\n        else:\n            raise NotImplementedError(f\"Unknown strategy: {config.strategy}\")\n        return reward_module\n\n    def init_model(self):\n        # This is used to import external_lib into the huggingface systems\n        import_external_libs(self.config.model.external_lib)\n        self.reward_module = self._build_model(config=self.config)\n\n    def _forward_micro_batch(self, micro_batch):\n        if is_cuda_available:\n            from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input\n        elif is_npu_available:\n            from transformers.integrations.npu_flash_attention import index_first_axis, pad_input, rearrange, unpad_input\n\n        from siirl.utils.model_utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs\n\n        with torch.no_grad(), torch.autocast(device_type=device_name, dtype=torch.bfloat16):\n            input_ids = micro_batch[\"input_ids\"]\n            
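# Shape sketch: input_ids/attention_mask/position_ids arrive padded as (batch_size, seqlen); the\n            # remove-padding branch flattens them to (1, total_nnz) for flash-attn varlen and pads the\n            # per-token scores back before the last valid token is selected.\n            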
batch_size, seqlen = input_ids.shape\n            attention_mask = micro_batch[\"attention_mask\"]\n            position_ids = micro_batch[\"position_ids\"]\n\n            if self.use_remove_padding:\n                input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask)  # input_ids_rmpad (total_nnz, ...)\n                input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)\n\n                # unpad the position_ids to align the rotary\n                position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), \"b s ... -> (b s) ...\"), indices).transpose(0, 1)\n\n                # pad and slice the inputs if sp > 1\n                if self.ulysses_sequence_parallel_size > 1:\n                    input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size)\n\n                # only pass input_ids and position_ids to enable flash_attn_varlen\n                output = self.reward_module(input_ids=input_ids_rmpad, attention_mask=None, position_ids=position_ids_rmpad, use_cache=False)  # prevent model thinks we are generating\n                reward_rmpad = output.logits\n                reward_rmpad = reward_rmpad.squeeze(0)  # (total_nnz)\n\n                # gather output if sp > 1\n                if self.ulysses_sequence_parallel_size > 1:\n                    reward_rmpad = gather_outpus_and_unpad(reward_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size)\n\n                # pad it back\n                rm_score = pad_input(reward_rmpad, indices=indices, batch=batch_size, seqlen=seqlen).squeeze(-1)\n            else:\n                output = self.reward_module(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False)\n                rm_score = output.logits  # (batch_size, seq_len, 1)\n                rm_score = rm_score.squeeze(-1)\n\n            # extract the result of the last valid token\n            eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1)  # (bsz,)\n            rm_score = rm_score[torch.arange(batch_size), eos_mask_idx]\n            return rm_score\n\n    def _expand_to_token_level(self, data: TensorDict, scores: torch.Tensor):\n        batch_size = data.batch.batch_size[0]\n        # expand as token_level_reward\n        attention_mask = data.batch[\"attention_mask\"]\n        position_ids = data.batch[\"position_ids\"]\n        response_length = data.batch[\"responses\"].shape[-1]\n        eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1)  # (bsz,)\n        token_level_scores = torch.zeros_like(attention_mask, dtype=scores.dtype)  # (bsz, seqlen)\n        token_level_scores[torch.arange(batch_size), eos_mask_idx] = scores\n\n        # select the response part\n        token_level_scores = token_level_scores[:, -response_length:]\n\n        return token_level_scores\n\n    def _switch_chat_template(self, data: TensorDict):\n        src_max_length = data.batch[\"attention_mask\"].shape[-1]\n\n        src_tokenizer = self.input_tokenizer\n        target_tokenizer = self.tokenizer\n\n        rm_input_ids = []\n        rm_attention_mask = []\n\n        for i in range(data.batch.batch_size[0]):\n            # extract raw prompt\n            if isinstance(data.non_tensor_batch[\"raw_prompt\"][i], list):\n                chat: list = data.non_tensor_batch[\"raw_prompt\"][i]\n            else:\n                chat: list = 
data.non_tensor_batch[\"raw_prompt\"][i].tolist()\n\n            # extract response\n            response_ids = data.batch[\"responses\"][i]\n            response_length = response_ids.shape[-1]\n            valid_response_length = data.batch[\"attention_mask\"][i][-response_length:].sum()\n            valid_response_ids = response_ids[:valid_response_length]\n\n            # decode\n            response = src_tokenizer.decode(valid_response_ids)\n            # remove bos and eos\n            response = response.replace(src_tokenizer.eos_token, \"\")\n\n            chat.append({\"role\": \"assistant\", \"content\": response})\n\n            prompt_with_chat_template = target_tokenizer.apply_chat_template(chat, add_generation_prompt=False, tokenize=False)\n            if self.rank == 0 and i == 0:\n                # for debugging purpose\n                logger.info(f\"Switch template. chat: {prompt_with_chat_template}\")\n\n            # the maximum length is actually determined by the reward model itself\n            max_length = self.config.max_length\n            if max_length is None:\n                max_length = src_max_length\n\n            model_inputs = target_tokenizer(prompt_with_chat_template, return_tensors=\"pt\", add_special_tokens=False)\n            input_ids, attention_mask = F.postprocess_data(\n                input_ids=model_inputs[\"input_ids\"],\n                attention_mask=model_inputs[\"attention_mask\"],\n                max_length=max_length,\n                pad_token_id=target_tokenizer.pad_token_id,\n                left_pad=False,  # right padding\n                truncation=self.config.truncation,\n            )  # truncate from the right\n\n            rm_input_ids.append(input_ids)\n            rm_attention_mask.append(attention_mask)\n\n        rm_input_ids = torch.cat(rm_input_ids, dim=0)\n        rm_attention_mask = torch.cat(rm_attention_mask, dim=0)\n\n        rm_position_ids = compute_position_id_with_mask(rm_attention_mask)\n\n        rm_inputs = {\"input_ids\": rm_input_ids, \"attention_mask\": rm_attention_mask, \"position_ids\": rm_position_ids}\n\n        return TensorDict.from_dict(rm_inputs)\n\n    def compute_rm_score(self, data: TensorDict):\n        import itertools\n\n        from siirl.utils.model_utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches\n\n        # Support all hardwares\n        data = data.to(get_device_id())\n        if self._do_switch_chat_template:\n            rm_data = self._switch_chat_template(data)\n        else:\n            rm_input_ids = data.batch[\"input_ids\"]\n            rm_attention_mask = data.batch[\"attention_mask\"]\n            rm_position_ids = data.batch[\"position_ids\"]\n            rm_inputs = {\n                \"input_ids\": rm_input_ids,\n                \"attention_mask\": rm_attention_mask,\n                \"position_ids\": rm_position_ids,\n            }\n            rm_data = TensorDict.from_dict(rm_inputs)\n\n        # Support all hardwares\n        rm_data.batch = rm_data.batch.to(get_device_id())\n\n        # perform forward computation\n        with self.ulysses_sharding_manager:\n            rm_data = self.ulysses_sharding_manager.preprocess_data(data=rm_data)\n            data = self.ulysses_sharding_manager.preprocess_data(data=data)\n\n            use_dynamic_bsz = self.config.use_dynamic_bsz\n            if use_dynamic_bsz:\n                max_token_len = self.config.forward_max_token_len_per_gpu * self.ulysses_sequence_parallel_size\n                
micro_batches, indices = rearrange_micro_batches(batch=rm_data.batch, max_token_len=max_token_len)\n            else:\n                micro_batches = rm_data.batch.split(self.config.micro_batch_size_per_gpu)\n            output = []\n            for micro_batch in micro_batches:\n                rm_score = self._forward_micro_batch(micro_batch)\n                output.append(rm_score)\n            scores = torch.cat(output, dim=0)  # (batch_size)\n\n            if use_dynamic_bsz:\n                indices = list(itertools.chain.from_iterable(indices))\n                assert len(indices) == scores.size(0), f\"{len(indices)} vs. {scores.size()}\"\n                revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)\n                scores = scores[revert_indices]\n\n            token_level_scores = self._expand_to_token_level(data, scores)\n            # Note that this is only the scores, may not be the final rewards used to train RL\n            data.batch[\"rm_scores\"] = token_level_scores\n            processed_data = self.ulysses_sharding_manager.postprocess_data(data=data)\n\n        # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes\n        # unshard the root FSDP module\n        if self.world_size > 1 and fsdp_version(self.reward_module) == 1:\n            self.reward_module._handle.reshard(True)\n\n        processed_data = processed_data.to(\"cpu\")\n        return processed_data\n\n\n# ================================= Async related workers =================================\nclass AsyncActorRolloutRefWorker(ActorRolloutRefWorker):\n    def _build_rollout(self, trust_remote_code=False):\n        rollout, rollout_sharding_manager = super()._build_rollout(trust_remote_code)\n\n        # NOTE: rollout is not actually initialized here, it's deferred\n        # to be initialized by AsyncvLLMServer.\n\n        self.vllm_tp_size = self.config.rollout.tensor_model_parallel_size\n        self.vllm_dp_rank = int(os.environ[\"RANK\"]) // self.vllm_tp_size\n        self.vllm_tp_rank = int(os.environ[\"RANK\"]) % self.vllm_tp_size\n\n        # used for sleep/wake_up\n        rollout.sharding_manager = rollout_sharding_manager\n\n        return rollout, rollout_sharding_manager\n\n    def generate_sequences(self, prompts: TensorDict):\n        raise NotImplementedError(\"AsyncActorRolloutRefWorker does not support generate_sequences\")\n\n    def execute_method(self, method: Union[str, bytes], *args, **kwargs):\n        \"\"\"Called by ExternalRayDistributedExecutor collective_rpc.\"\"\"\n        if self.vllm_tp_rank == 0 and method != \"execute_model\":\n            print(f\"[DP={self.vllm_dp_rank},TP={self.vllm_tp_rank}] execute_method: {method if isinstance(method, str) else 'Callable'}\")\n        return self.rollout.execute_method(method, *args, **kwargs)\n\n    async def chat_completion(self, json_request):\n        ret = await self.rollout.chat_completion(json_request)\n        return ret\n\n    async def wake_up(self):\n        await self.rollout.wake_up()\n        # return something to block the caller\n        return True\n\n    async def sleep(self):\n        await self.rollout.sleep()\n        # return something to block the caller\n        return True\n    \n    def set_rollout_sharding_manager(self, sharding_manager):      \n        super().set_rollout_sharding_manager(sharding_manager)\n        \n            \n    def get_zeromq_address(self):\n        return self.rollout.get_zeromq_address()\n\n"
  },
  {
    "path": "siirl/engine/megatron_workers.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, Infrawaves. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThe main entry point to run the PPO algorithm\n\"\"\"\n\nimport os\nimport time\nimport warnings\nfrom typing import Union\nimport datetime\nimport psutil\n\nimport torch\nimport torch.distributed\nfrom codetiming import Timer\nfrom loguru import logger\n\nfrom omegaconf import DictConfig, OmegaConf\nfrom tensordict import TensorDict\n# Handle different tensordict versions - NonTensorData location varies\ntry:\n    from tensordict import NonTensorData\nexcept ImportError:\n    from tensordict.tensorclass import NonTensorData\ntry:\n    from mindspeed.megatron_adaptor import repatch\nexcept ImportError:\n    repatch = None\n\nfrom megatron.core import parallel_state as mpu\n\nfrom siirl.engine.base_worker.megatron.worker import MegatronWorker\nfrom siirl.models.loader import load_tokenizer\nfrom siirl.utils.checkpoint.megatron_checkpoint_manager import MegatronCheckpointManager\nfrom siirl.utils.debug import GPUMemoryLogger, log_gpu_memory_usage\nfrom siirl.utils.model_utils.flops_counter import FlopsCounter\nfrom siirl.utils.extras.fs import copy_to_local\nfrom siirl.utils.megatron.megatron_utils import (\n    load_megatron_model_to_gpu,\n    load_megatron_optimizer,\n    offload_megatron_model_to_cpu,\n    offload_megatron_optimizer,\n)\nfrom siirl.utils.extras.import_utils import import_external_libs\nfrom siirl.utils.extras.device import get_device_id, get_device_name, get_nccl_backend, get_torch_device\nfrom siirl.utils.model_utils.model import get_hf_model_path, load_mcore_dist_weights, load_megatron_gptmodel_weights\nfrom siirl.utils.model_utils.torch_dtypes import PrecisionType\nfrom siirl.params.model_args import ActorRolloutRefArguments\nfrom siirl.engine.actor.megatron_actor import MegatronPPOActor\nfrom siirl.engine.critic.megatron_critic import MegatronPPOCritic\nfrom siirl.engine.reward_model.megatron.reward_model import MegatronRewardModel\n\n\ndef set_random_seed(seed):\n    import random\n\n    import numpy as np\n    import torch\n\n    torch.manual_seed(seed)\n    np.random.seed(seed)\n    random.seed(seed)\n    if get_torch_device().device_count() > 0:\n        from megatron.core import tensor_parallel\n\n        tensor_parallel.model_parallel_cuda_manual_seed(seed)\n    # FIXME: torch cumsum not support deterministic (used in vllm sampler),\n    # https://github.com/pytorch/pytorch/issues/89492\n    # torch.use_deterministic_algorithms(True, warn_only=True)\n    # os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'\n\n\n# TODO(Ping Zhang): We will deprecate this hybrid worker in the future.\nclass ActorRolloutRefWorker(MegatronWorker):\n    \"\"\"\n    This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy\n    or a hybrid engine based on the config.rollout\n    \"\"\"\n\n    def __init__(self, config: DictConfig, role: str, 
process_group=None):\n        super().__init__()\n        self.config = config\n        global_mindspeed_repatch(self.config.megatron.override_transformer_config)\n        # self.process_group = process_group\n\n        # NOTE(sgm): We utilize colocate WorkerGroup by default.\n        # As a result, Workers for different models share the same process.\n        # Therefore, we only require one distributed initialization.\n        # To utilize different parallel strategies in different models:\n        # 1. users should disable WorkerDict; 2. assign different ResourcePool to different models;\n        # 3. apply the following patch in ray==2.10, https://github.com/ray-project/ray/pull/44385\n        if not torch.distributed.is_initialized():\n            # Use LOCAL_RANK for device setting, but respect process group for distributed ops\n            rank = int(os.environ[\"LOCAL_RANK\"])\n            torch.distributed.init_process_group(backend=get_nccl_backend())\n            get_torch_device().set_device(rank)\n            if self.config.actor.megatron.sequence_parallel:\n                os.environ[\"CUDA_DEVICE_MAX_CONNECTIONS\"] = \"1\"\n        \n            mpu.initialize_model_parallel(\n                tensor_model_parallel_size=self.config.actor.megatron.tensor_model_parallel_size,\n                pipeline_model_parallel_size=self.config.actor.megatron.pipeline_model_parallel_size,\n                virtual_pipeline_model_parallel_size=self.config.actor.megatron.virtual_pipeline_model_parallel_size,\n                pipeline_model_parallel_split_rank=None,\n                use_sharp=False,\n                context_parallel_size=self.config.actor.megatron.context_parallel_size,\n                expert_model_parallel_size=self.config.actor.megatron.expert_model_parallel_size,\n                expert_tensor_parallel_size=self.config.actor.megatron.expert_tensor_parallel_size,\n                nccl_communicator_config_path=None,\n            )\n\n        set_random_seed(seed=self.config.actor.megatron.seed)\n\n        self.role = role\n        assert self.role in [\"actor\", \"rollout\", \"ref\", \"actor_rollout\", \"actor_rollout_ref\"]\n\n        self._is_actor = self.role in [\"actor\", \"actor_rollout\", \"actor_rollout_ref\"]\n        self._is_rollout = self.role in [\"rollout\", \"actor_rollout\", \"actor_rollout_ref\"]\n        self._is_ref = self.role in [\"ref\", \"actor_rollout_ref\"]\n\n        # TODO(sgm): Currently, we only support reference model param offload\n        # will support other offload later\n        self._is_offload_param = False\n        self._is_offload_grad = False\n        self._is_offload_optimizer = False\n\n        # normalize config\n        if self._is_actor and self._is_rollout:\n            self.config.actor.ppo_mini_batch_size *= self.config.rollout.n\n            self.config.actor.ppo_mini_batch_size //= mpu.get_data_parallel_world_size()\n            if self.config.actor.ppo_micro_batch_size:\n                self.config.actor.ppo_micro_batch_size //= mpu.get_data_parallel_world_size()\n                self.config.rollout.log_prob_micro_batch_size //= mpu.get_data_parallel_world_size()\n                self.config.actor.ppo_micro_batch_size_per_gpu = self.config.actor.ppo_micro_batch_size\n                self.config.rollout.log_prob_micro_batch_size_per_gpu = self.config.rollout.log_prob_micro_batch_size\n\n            self._is_offload_param = self.config.actor.megatron.param_offload\n            self._is_offload_grad = 
self.config.actor.megatron.grad_offload\n            self._is_offload_optimizer = self.config.actor.megatron.optimizer_offload\n        elif self._is_ref:\n            if self.config.ref.log_prob_micro_batch_size:\n                self.config.ref.log_prob_micro_batch_size //= mpu.get_data_parallel_world_size()\n                self.config.ref.log_prob_micro_batch_size_per_gpu = self.config.ref.log_prob_micro_batch_size\n            else:\n                assert self.config.ref.log_prob_micro_batch_size_per_gpu is not None, \"Please note that in the ref policy configuration, `log_prob_micro_batch_size_per_gpu` and `log_prob_micro_batch_size` should not be None at the same time.\"\n            self._ref_is_offload_param = self.config.ref.megatron.param_offload\n\n    def _build_model_optimizer(self, model_path, optim_config, override_model_config, override_transformer_config):\n        from megatron.core.models.gpt.gpt_model import ModelType\n\n        from siirl.utils.megatron.optimizer import get_megatron_optimizer\n        from siirl.utils.megatron.megatron_utils import get_model, init_megatron_optim_config\n        from siirl.utils.model_utils.model import get_generation_config, print_model_size\n\n        self._init_hf_config_and_tf_config(model_path, model_path, self.dtype, override_model_config, override_transformer_config, self.config.model.trust_remote_code)\n        self.generation_config = get_generation_config(self.local_path)\n\n        def megatron_actor_model_provider(pre_process, post_process):\n            from siirl.models.mcore import init_mcore_model\n\n            parallel_model = init_mcore_model(self.tf_config, self.hf_config, pre_process, post_process, share_embeddings_and_output_weights=self.share_embeddings_and_output_weights, value=False, freeze_moe_router=override_model_config.get(\"moe_config\", {}).get(\"freeze_moe_router\", False))\n            parallel_model.to(get_device_name())\n            return parallel_model\n\n        actor_module = None\n        # Step 3: initialize the megatron model\n        if self._is_actor and self._is_rollout:\n            actor_module = get_model(\n                megatron_actor_model_provider,\n                wrap_with_ddp=True,\n                use_distributed_optimizer=self.config.actor.megatron.use_distributed_optimizer,\n            )\n            print(f\"actor_module: {len(actor_module)}\")\n            if self.config.actor.load_weight:\n                if self.config.actor.megatron.use_dist_checkpointing:\n                    load_mcore_dist_weights(actor_module, self.config.actor.megatron.dist_checkpointing_path, is_value_model=False)\n                else:\n                    load_megatron_gptmodel_weights(self.config, self.hf_config, actor_module, params_dtype=self.dtype, is_value_model=False)\n\n            if self.rank == 0:\n                print_model_size(actor_module[0])\n            log_gpu_memory_usage(\"After MegatronPPOActor init\", logger=logger)\n        elif self._is_ref:\n            print(f\"self.config.ref.load_weight: {self.config.ref.load_weight}\")\n            ref_module = get_model(\n                model_provider_func=megatron_actor_model_provider,\n                model_type=ModelType.encoder_or_decoder,\n                wrap_with_ddp=False,\n                use_distributed_optimizer=self.config.ref.megatron.use_distributed_optimizer,\n            )\n            # ref_module = nn.ModuleList(ref_module)\n\n            if self.config.ref.load_weight:  # should align with the actor:\n             
   assert self.config.actor.load_weight == self.config.ref.load_weight\n                print(\"load ref weight start\")\n                if self.config.ref.megatron.use_dist_checkpointing:\n                    load_mcore_dist_weights(ref_module, self.config.ref.megatron.dist_checkpointing_path, is_value_model=False)\n                else:\n                    load_megatron_gptmodel_weights(self.config, self.hf_config, ref_module, params_dtype=self.dtype, is_value_model=False)\n            log_gpu_memory_usage(\"After ref module init\", logger=logger)\n            return ref_module, self.hf_config\n\n        # TODO: add more optimizer args into config\n        if self._is_actor:\n            optim_config = init_megatron_optim_config(optim_config)\n            actor_optimizer = get_megatron_optimizer(model=actor_module, config=optim_config)\n        else:\n            optim_config = None\n            actor_optimizer = None\n\n        log_gpu_memory_usage(\"After actor optimizer init\", logger=logger)\n\n        return actor_module, actor_optimizer, self.hf_config, optim_config\n\n    def _build_rollout(self, trust_remote_code=False):\n        from torch.distributed.device_mesh import init_device_mesh\n\n        if self.config.rollout.name == \"vllm\":\n            from torch.distributed.device_mesh import init_device_mesh\n\n            from siirl.engine.rollout.vllm_rollout import vllm_mode, vLLMRollout\n            # NOTE(sgm): If the QKV and gate_up projection layers are concatenated together in the actor,\n            # we will reorganize their weight format when resharding from actor to rollout.\n\n            infer_tp = self.config.rollout.tensor_model_parallel_size\n            dp = self.world_size // infer_tp\n            assert self.world_size % infer_tp == 0, f\"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}\"\n            rollout_device_mesh = init_device_mesh(get_device_name(), mesh_shape=(dp, infer_tp), mesh_dim_names=[\"dp\", \"infer_tp\"])\n            log_gpu_memory_usage(\"Before building vllm rollout\", logger=None)\n\n            local_path = copy_to_local(self.config.model.path, use_shm=self.config.model.use_shm)\n            if vllm_mode == \"customized\":\n                rollout = vLLMRollout(\n                    actor_module=self.actor_module,\n                    config=self.config.rollout,\n                    tokenizer=self.tokenizer,\n                    model_hf_config=self.actor_model_config,\n                )\n            elif vllm_mode == \"spmd\":\n                rollout = vLLMRollout(\n                    model_path=local_path,\n                    config=self.config.rollout,\n                    tokenizer=self.tokenizer,\n                    model_hf_config=self.actor_model_config,\n                    device_mesh=rollout_device_mesh,\n                    trust_remote_code=trust_remote_code,\n                )\n            log_gpu_memory_usage(\"After building vllm rollout\", logger=logger)\n\n        elif self.config.rollout.name in [\"sglang\", \"sglang_async\"]:\n            if self.config.rollout.name == \"sglang_async\":\n                warnings.warn(\n                    \"'sglang_async' has been deprecated and merged into 'sglang'. Please use 'sglang' going forward.\",\n                    DeprecationWarning,\n                    stacklevel=2,\n                )\n            from siirl.engine.rollout.sglang_rollout import SGLangRollout\n\n            # NOTE(linjunrong): Due to recent fp8 support in SGLang. 
Now importing any symbol relate to SGLang's model_runner would check CUDA device capability.\n            # However, due to siirl's setting, the main process of ray can not find any CUDA device, which would potentially lead to:\n            # \"RuntimeError: No CUDA GPUs are available\".\n            # For this reason, sharding_manager.__init__ should not import FSDPSGLangShardingManager and we import it here use the abs path.\n            # check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76\n            from siirl.engine.sharding_manager.megatron_sglang import MegatronSGLangShardingManager\n\n            infer_tp = self.config.rollout.tensor_model_parallel_size\n            dp = self.world_size // infer_tp\n            assert self.world_size % infer_tp == 0, f\"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}\"\n            rollout_device_mesh = init_device_mesh(\"cpu\", mesh_shape=(dp, infer_tp, 1), mesh_dim_names=(\"dp\", \"tp\", \"pp\"))\n\n            local_path = copy_to_local(self.config.model.path)\n            log_gpu_memory_usage(f\"Before building {self.config.rollout.name} rollout\", logger=None)\n            rollout = SGLangRollout(\n                actor_module=local_path,\n                config=self.config.rollout,\n                tokenizer=self.tokenizer,\n                model_hf_config=self.actor_model_config,\n                trust_remote_code=trust_remote_code,\n                device_mesh=rollout_device_mesh,\n            )\n            log_gpu_memory_usage(f\"After building {self.config.rollout.name} rollout\", logger=None)\n        else:\n            raise NotImplementedError(\"Only vllmRollout and SGLangRollout are supported with Megatron now\")\n        \n        print(\"rollout init done\")\n        return rollout, None\n\n    def init_model(self):\n        import_external_libs(self.config.model.external_lib)\n\n        override_model_config = self.config.model.override_config\n        if self._is_actor:\n            override_transformer_config = self.config.actor.megatron.override_transformer_config\n        elif self._is_ref:\n            override_transformer_config = self.config.ref.megatron.override_transformer_config\n        else:\n            override_transformer_config = None\n        \n        if not override_transformer_config:\n            override_transformer_config = OmegaConf.create()\n        \n        self.param_dtype = torch.bfloat16\n        log_gpu_memory_usage(\"Before init actor model and optimizer\", logger=logger)\n\n        self.dtype = PrecisionType.to_dtype(self.param_dtype)\n\n        if self._is_actor or self._is_rollout:\n            # we need the model for actor and rollout\n            optim_config = self.config.actor.optim if self._is_actor else None\n            self.actor_module, self.actor_optimizer, self.actor_model_config, self.actor_optim_config = self._build_model_optimizer(\n                model_path=self.config.model.path,\n                optim_config=optim_config,\n                override_model_config=override_model_config,\n                override_transformer_config=override_transformer_config,\n            )\n            if self._is_offload_param:\n                offload_megatron_model_to_cpu(self.actor_module)\n                log_gpu_memory_usage(\"After offload actor params and grad during init\", logger=logger)\n            if self._is_offload_optimizer:\n                
offload_megatron_optimizer(self.actor_optimizer)\n                log_gpu_memory_usage(\"After offload actor optimizer during init\", logger=logger)\n\n        if self._is_actor:\n            self.actor = MegatronPPOActor(\n                config=self.config.actor,\n                model_config=self.actor_model_config,\n                hf_config=self.hf_config,\n                tf_config=self.tf_config,\n                actor_module=self.actor_module,\n                actor_optimizer=self.actor_optimizer,\n            )\n            log_gpu_memory_usage(\"After MegatronPPOActor init\", logger=logger)\n\n        if self._is_rollout:\n            self.rollout, self.sharding_manager = self._build_rollout(trust_remote_code=self.config.model.trust_remote_code)\n            # used for sleep/wake_up\n            self.rollout.sharding_manager = self.sharding_manager\n            log_gpu_memory_usage(\"After rollout init\", logger=logger)\n\n        if self._is_ref:\n            self.ref_module, self.ref_model_config = self._build_model_optimizer(\n                model_path=self.config.model.path,\n                optim_config=None,\n                override_model_config=override_model_config,\n                override_transformer_config=override_transformer_config,\n            )\n            log_gpu_memory_usage(\"After ref model init\", logger=logger)\n            self.ref_policy = MegatronPPOActor(\n                config=self.config.ref,\n                model_config=self.ref_model_config,\n                hf_config=self.hf_config,\n                tf_config=self.tf_config,\n                actor_module=self.ref_module,\n                actor_optimizer=None,\n            )\n            if self._ref_is_offload_param:\n                offload_megatron_model_to_cpu(self.ref_module)\n                log_gpu_memory_usage(\"After offload ref params during init\", logger=logger)\n\n        if self._is_actor:\n            self.flops_counter = FlopsCounter(self.actor_model_config)\n            self.checkpoint_mananager = MegatronCheckpointManager(\n                config=self.config,\n                model_config=self.actor_model_config,\n                role=\"actor\",\n                model=self.actor_module,\n                arch=self.architectures[0],\n                hf_config=self.hf_config,\n                param_dtype=self.param_dtype,\n                share_embeddings_and_output_weights=self.share_embeddings_and_output_weights,\n                tokenizer=self.tokenizer,\n                optimizer=self.actor_optimizer,\n                use_distributed_optimizer=self.config.actor.megatron.use_distributed_optimizer,\n                checkpoint_contents=self.config.actor.checkpoint.contents,\n            )\n        get_torch_device().empty_cache()\n        log_gpu_memory_usage(\"After init_model finish\", logger=logger)\n\n    @GPUMemoryLogger(role=\"update_actor\", logger=logger)\n    def update_actor(self, data: TensorDict):\n        assert self._is_actor\n        if self._is_offload_param:\n            load_megatron_model_to_gpu(self.actor_module)\n            log_gpu_memory_usage(\"After load actor params and grad during update_actor\", logger=logger)\n        if self._is_offload_optimizer:\n            load_megatron_optimizer(self.actor_optimizer)\n            log_gpu_memory_usage(\"After load actor optimizer during update_actor\", logger=logger)\n        data = data.to(get_device_name())\n\n        micro_batch_size = self.config.actor.ppo_micro_batch_size_per_gpu\n        
data[\"micro_batch_size\"] = NonTensorData(micro_batch_size)\n        with Timer(name=\"update_policy\", logger=None) as timer:\n            metrics = self.actor.update_policy(data=data)\n        delta_time = timer.last\n        global_num_tokens = data.meta_info[\"global_token_num\"]\n        estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)\n        metrics[\"perf/mfu/actor\"] = estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size\n\n        # TODO: here, we should return all metrics\n        output = TensorDict(meta_info={\"metrics\": metrics})\n        output = output.to(\"cpu\")\n\n        if self._is_offload_param:\n            offload_megatron_model_to_cpu(self.actor_module)\n            log_gpu_memory_usage(\"After offload actor params and grad during update_actor\", logger=logger)\n        if self._is_offload_optimizer:\n            offload_megatron_optimizer(self.actor_optimizer)\n            log_gpu_memory_usage(\"After offload actor optimizer during update_actor\", logger=logger)\n\n        get_torch_device().empty_cache()\n        return output\n\n    @GPUMemoryLogger(role=\"generate_sequences\", logger=logger)\n    def generate_sequences(self, prompts: TensorDict):\n        assert self._is_rollout\n        if self._is_offload_param:\n            load_megatron_model_to_gpu(self.actor_module)\n            log_gpu_memory_usage(\"After load actor params during generate_sequences\", logger=logger)\n        prompts.batch = prompts.batch.to(get_device_name())\n        meta_info = {\n            \"eos_token_id\": self.generation_config.eos_token_id if self.generation_config is not None else self.tokenizer.eos_token_id,\n            \"pad_token_id\": self.generation_config.pad_token_id if self.generation_config is not None else self.tokenizer.pad_token_id,\n        }\n        prompts.meta_info.update(meta_info)\n        if self._is_offload_optimizer:\n            offload_megatron_optimizer(self.actor_optimizer)\n\n        with self.sharding_manager:\n            if self._is_offload_param:\n                offload_megatron_model_to_cpu(self.actor_module)\n            log_gpu_memory_usage(\"After entering sharding manager\", logger=logger)\n\n            # (zhangchi.usc1992) wake up kv cache here. Currently only support vllm.\n            # Will support sglang once separate wakeup of model weights and kv cache is supported\n            # This API should be exposed by the rollout. 
Will rewrite this part when we refactor after v0.4 release.\n            # Currently, we hack here to support running large models (QWen3-236b and DeepSeek-671b)\n            if self.config.rollout.name == \"vllm\":\n                import inspect\n\n                if \"tags\" in inspect.signature(self.rollout.inference_engine.wake_up).parameters:\n                    self.rollout.inference_engine.wake_up(tags=[\"kv_cache\"])\n\n            output = self.rollout.generate_sequences(prompts=prompts)\n\n        output = output.to(\"cpu\")\n        # clear kv cache\n        get_torch_device().empty_cache()\n        return output\n\n    def load_checkpoint(self, checkpoint_path, hdfs_path=None, del_local_after_load=True):\n        if self._is_offload_param:\n            load_megatron_model_to_gpu(self.actor_module)\n        self.checkpoint_mananager.load_checkpoint(local_path=checkpoint_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load)\n        if self._is_offload_param:\n            offload_megatron_model_to_cpu(self.actor_module)\n        if self._is_offload_optimizer:\n            offload_megatron_optimizer(self.actor_optimizer)\n\n    def load_pretrained_model(self, checkpoint_path, del_local_after_load=True):\n        pass\n\n    def save_checkpoint(self, checkpoint_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):\n        if self._is_offload_param:\n            load_megatron_model_to_gpu(self.actor_module)\n        self.checkpoint_mananager.save_checkpoint(local_path=checkpoint_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep)\n        torch.distributed.barrier()\n        if self._is_offload_param:\n            offload_megatron_model_to_cpu(self.actor_module)\n\n# TODO(Ping Zhang): We will deprecate this hybrid worker in the future.\nclass AsyncActorRolloutRefWorker(ActorRolloutRefWorker):\n    def _build_rollout(self, trust_remote_code=False):\n        rollout, rollout_sharding_manager = super()._build_rollout(trust_remote_code)\n\n        # NOTE: rollout is not actually initialized here, it's deferred\n        # to be initialized by AsyncvLLMServer.\n\n        self.vllm_tp_size = self.config.rollout.tensor_model_parallel_size\n        self.vllm_dp_rank = int(os.environ[\"RANK\"]) // self.vllm_tp_size\n        self.vllm_tp_rank = int(os.environ[\"RANK\"]) % self.vllm_tp_size\n\n        # used for sleep/wake_up\n        rollout.sharding_manager = rollout_sharding_manager\n\n        return rollout, rollout_sharding_manager\n\n    def execute_method(self, method: Union[str, bytes], *args, **kwargs):\n        \"\"\"Called by ExternalRayDistributedExecutor collective_rpc.\"\"\"\n        if self.vllm_tp_rank == 0 and method != \"execute_model\":\n            print(f\"[DP={self.vllm_dp_rank},TP={self.vllm_tp_rank}] execute_method: {method if isinstance(method, str) else 'Callable'}\")\n        return self.rollout.execute_method(method, *args, **kwargs)\n\n    async def chat_completion(self, json_request):\n        ret = await self.rollout.chat_completion(json_request)\n        return ret\n\n    async def wake_up(self):\n        await self.rollout.wake_up()\n        # return something to block the caller\n        return True\n\n    async def sleep(self):\n        await self.rollout.sleep()\n        # return something to block the caller\n        return True\n\n\nclass CriticWorker(MegatronWorker):\n    def __init__(self, config, process_group=None):\n        super().__init__()\n        self.config = config\n        # 
self.process_group = process_group\n\n        # NOTE(sgm): We utilize colocate WorkerGroup by default.\n        # As a result, Workers for different model share the same process.\n        # Therefore, we only require one distribute initialization.\n        # To utilize different parallel startegy in different models:\n        # 1, users should disable WorkerDict; 2.assign different ResourcePool to different models,\n        # 3. and apply the following patch in ray==2.10, https://github.com/ray-project/ray/pull/44385\n        global_mindspeed_repatch(self.config.megatron.override_transformer_config)\n        if not torch.distributed.is_initialized():\n            # Use LOCAL_RANK for device setting, but respect process group for distributed ops\n            rank = int(os.environ[\"LOCAL_RANK\"])\n            torch.distributed.init_process_group(backend=get_nccl_backend())\n            get_torch_device().set_device(rank)\n\n            if self.config.megatron.sequence_parallel:\n                os.environ[\"CUDA_DEVICE_MAX_CONNECTIONS\"] = \"1\"\n            mpu.initialize_model_parallel(\n                tensor_model_parallel_size=self.config.megatron.tensor_model_parallel_size,\n                pipeline_model_parallel_size=self.config.megatron.pipeline_model_parallel_size,\n                virtual_pipeline_model_parallel_size=self.config.megatron.virtual_pipeline_model_parallel_size,\n                pipeline_model_parallel_split_rank=None,\n                use_sharp=False,\n                context_parallel_size=self.config.megatron.context_parallel_size,\n                expert_model_parallel_size=self.config.megatron.expert_model_parallel_size,\n                expert_tensor_parallel_size=self.config.megatron.expert_tensor_parallel_size,\n                nccl_communicator_config_path=None,\n            )\n\n        set_random_seed(seed=self.config.megatron.seed)\n\n        # set FSDP offload params\n        self._is_offload_param = self.config.megatron.param_offload\n        self._is_offload_optimizer = self.config.megatron.optimizer_offload\n\n        # normalize config\n        self.config.ppo_mini_batch_size *= self.config.rollout_n\n        self.config.ppo_mini_batch_size //= mpu.get_data_parallel_world_size()\n        if self.config.ppo_micro_batch_size:\n            self.config.ppo_micro_batch_size //= mpu.get_data_parallel_world_size()\n            self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size\n\n        # TODO(sgm): support critic model offload\n\n    def _build_critic_model_optimizer(self, model_path, optim_config, override_model_config, override_transformer_config, override_ddp_config):\n        from siirl.utils.megatron.optimizer import get_megatron_optimizer, get_megatron_optimizer_param_scheduler\n        from siirl.utils.megatron.megatron_utils import init_megatron_optim_config\n        from siirl.utils.model_utils.model import print_model_size\n        from siirl.utils.megatron.megatron_utils import McoreModuleWrapperConfig, make_megatron_module\n\n        self._init_hf_config_and_tf_config(\n            model_path, \n            model_path, \n            self.dtype, \n            override_model_config, \n            override_transformer_config, \n            self.config.model.trust_remote_code,\n            self.config.megatron.use_mbridge,\n        )\n\n        wrap_config = McoreModuleWrapperConfig(\n            is_value_model=True,  # critic is value model\n            share_embeddings_and_output_weights=False,\n            
wrap_with_ddp=True,\n            use_distributed_optimizer=self.config.megatron.use_distributed_optimizer,\n        )\n        critic_module = make_megatron_module(\n            wrap_config=wrap_config,\n            tf_config=self.tf_config,\n            hf_config=self.hf_config,\n            bridge=self.bridge,\n            override_model_config=override_model_config,\n            override_ddp_config=override_ddp_config,\n        )\n\n        # note that here critic_module will be a list to be compatible with the construction of interleaved pp (vpp).\n        # but here, we do not use pp (vpp) yet. For simplicity, we remove the list\n        # critic_module = nn.ModuleList(critic_module)\n\n        if self.config.load_weight:\n            t0 = time.time()\n            if self.config.megatron.use_dist_checkpointing:\n                load_mcore_dist_weights(\n                    critic_module, self.config.megatron.dist_checkpointing_path, is_value_model=True\n                )\n            else:\n                if self.bridge is not None:\n                    local_model_path = get_hf_model_path(self.config)\n                    self.bridge.load_weights(critic_module, local_model_path)\n                else:\n                    load_megatron_gptmodel_weights(\n                        self.config, self.hf_config, critic_module, params_dtype=self.dtype, is_value_model=True\n                    )\n            t1 = time.time()\n            if torch.distributed.get_rank() == 0:\n                print(f\"critic load_weight time: {t1 - t0}\")\n        if self.rank == 0:\n            print_model_size(critic_module[0])\n\n        # TODO: add more optimizer args into config\n        optim_config_megatron = init_megatron_optim_config(optim_config)\n        critic_optimizer = get_megatron_optimizer(model=critic_module, config=optim_config_megatron)\n        critic_optimizer_scheduler = get_megatron_optimizer_param_scheduler(\n            optimizer=critic_optimizer, config=optim_config\n        )\n        get_torch_device().empty_cache()\n        return critic_module, critic_optimizer, critic_optimizer_scheduler, self.hf_config, optim_config\n\n    def init_model(self):\n        # create critic\n        import_external_libs(self.config.model.external_lib)\n        override_model_config = self.config.model.override_config\n        override_transformer_config = self.config.megatron.override_transformer_config\n\n        if not override_transformer_config:\n            override_transformer_config = OmegaConf.create()\n        \n        override_ddp_config = self.config.megatron.override_ddp_config\n        if not override_ddp_config:\n            override_ddp_config = OmegaConf.create()\n        \n        self.param_dtype = torch.bfloat16\n        self.dtype = PrecisionType.to_dtype(self.param_dtype)\n        self.critic_module, self.critic_optimizer, self.critic_optimizer_scheduler, self.critic_model_config, critic_optimizer_config = self._build_critic_model_optimizer(\n            model_path=self.config.model.path,\n            optim_config=self.config.optim,\n            override_model_config=override_model_config,\n            override_transformer_config=override_transformer_config,\n            override_ddp_config=override_ddp_config,\n        )\n        if self._is_offload_param:\n            offload_megatron_model_to_cpu(self.critic_module)\n        if self._is_offload_optimizer:\n            offload_megatron_optimizer(self.critic_optimizer)\n\n        self.critic = MegatronPPOCritic(\n            
config=self.config,\n            model_config=self.critic_model_config,\n            hf_config=self.hf_config,\n            tf_config=self.tf_config,\n            critic_module=self.critic_module,\n            critic_optimizer=self.critic_optimizer,\n            critic_optimizer_config=critic_optimizer_config,\n        )\n        self.flops_counter = FlopsCounter(self.critic_model_config)\n        self.checkpoint_mananager = MegatronCheckpointManager(\n            config=self.config,\n            checkpoint_config=self.config.checkpoint,\n            model_config=self.critic_model_config,\n            transformer_config=self.tf_config,\n            role=\"critic\",\n            model=self.critic_module,\n            arch=self.architectures[0],\n            hf_config=self.hf_config,\n            param_dtype=self.param_dtype,\n            share_embeddings_and_output_weights=False,\n            processing_class=self.processor if self.processor is not None else self.tokenizer,\n            optimizer=self.critic_optimizer,\n            optimizer_scheduler=self.critic_optimizer_scheduler,\n            use_distributed_optimizer=self.config.megatron.use_distributed_optimizer,\n            use_checkpoint_opt_param_scheduler=self.config.optim.use_checkpoint_opt_param_scheduler,\n            bridge=self.bridge,\n            use_dist_checkpointing=self.config.megatron.use_dist_checkpointing,\n        )\n\n    def compute_values(self, data: TensorDict):\n        micro_batch_size = self.config.ppo_micro_batch_size_per_gpu\n        data[\"micro_batch_size\"] = NonTensorData(micro_batch_size)\n        data[\"max_token_len\"] = NonTensorData(self.config.forward_max_token_len_per_gpu)\n        data[\"use_dynamic_bsz\"] = NonTensorData(self.config.use_dynamic_bsz)\n        data = data.to(get_device_id())\n        if self._is_offload_param:\n            load_megatron_model_to_gpu(self.critic_module)\n        values = self.critic.compute_values(data=data)\n        data[\"values\"] = values\n        data = data.to(\"cpu\")\n        if self._is_offload_param:\n            offload_megatron_model_to_cpu(self.critic_module)\n        return data\n\n    def update_critic(self, data: TensorDict):\n        data = data.to(get_device_id())\n\n        if self._is_offload_param:\n            load_megatron_model_to_gpu(self.critic_module)\n        if self._is_offload_optimizer:\n            load_megatron_optimizer(self.critic_optimizer)\n        with Timer(name=\"update_critic\", logger=None) as timer:\n            metrics = self.critic.update_critic(data=data)\n        delta_time = timer.last\n        global_num_tokens = data[\"global_token_num\"]\n        estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)\n        metrics[\"perf/mfu/critic\"] = estimated_flops * self.config.ppo_epochs / promised_flops\n        metrics[\"perf/delta_time/critic\"] = delta_time\n        data[\"metrics\"] = NonTensorData(metrics)\n        data = data.to(\"cpu\")\n\n        if self._is_offload_param:\n            offload_megatron_model_to_cpu(self.critic_module)\n        if self._is_offload_optimizer:\n            offload_megatron_optimizer(self.critic_optimizer)\n        \n        return data\n\n    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):\n        if self._is_offload_param:\n            load_megatron_model_to_gpu(self.critic_module)\n        self.checkpoint_mananager.load_checkpoint(local_path=local_path, hdfs_path=hdfs_path, 
del_local_after_load=del_local_after_load)\n        if self._is_offload_param:\n            offload_megatron_model_to_cpu(self.critic_module)\n        if self._is_offload_optimizer:\n            offload_megatron_optimizer(self.critic_optimizer)\n\n    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):\n        if self._is_offload_param:\n            load_megatron_model_to_gpu(self.critic_module)\n        self.checkpoint_mananager.save_checkpoint(local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep)\n        if self._is_offload_param:\n            offload_megatron_model_to_cpu(self.critic_module)\n\n\nclass RewardModelWorker(MegatronWorker):\n    \"\"\"\n    Note that we only implement the reward model that is subclass of AutoModelForSequenceClassification.\n    \"\"\"\n\n    def __init__(self, config, process_group=None):\n        super().__init__()\n        self.config = config\n        # self.process_group = process_group\n\n        # NOTE(sgm): We utilize colocate WorkerGroup by default.\n        # As a result, Workers for different model share the same process.\n        # Therefore, we only require one distribute initialization.\n        # To utilize different parallel startegy in different models:\n        # 1, users should disable WorkerDict; 2.assign different ResourcePool to different models,\n        # 3. and apply the following patch in ray==2.10, https://github.com/ray-project/ray/pull/44385\n        global_mindspeed_repatch(self.config.actor.megatron.get(\"override_transformer_config\", {}))\n        if not torch.distributed.is_initialized():\n            rank = int(os.environ[\"LOCAL_RANK\"])\n            torch.distributed.init_process_group(backend=get_nccl_backend())\n            get_torch_device().set_device(rank)\n            if self.config.megatron.sequence_parallel:\n                os.environ[\"CUDA_DEVICE_MAX_CONNECTIONS\"] = \"1\"\n            mpu.initialize_model_parallel(\n                tensor_model_parallel_size=self.config.megatron.tensor_model_parallel_size,\n                pipeline_model_parallel_size=self.config.megatron.pipeline_model_parallel_size,\n                virtual_pipeline_model_parallel_size=self.config.megatron.virtual_pipeline_model_parallel_size,\n                pipeline_model_parallel_split_rank=None,\n                use_sharp=False,\n                context_parallel_size=self.config.megatron.context_parallel_size,\n                expert_model_parallel_size=self.config.megatron.expert_model_parallel_size,\n                expert_tensor_parallel_size=self.config.megatron.expert_tensor_parallel_size,\n                nccl_communicator_config_path=None,\n            )\n\n        set_random_seed(seed=self.config.megatron.seed)\n\n        # normalize config\n        if self.config.micro_batch_size is not None:\n            self.config.micro_batch_size //= mpu.get_data_parallel_world_size()\n            self.config.micro_batch_size_per_gpu = self.config.micro_batch_size\n\n    def _build_rm_model(self, model_path, tokenizer, override_model_config, override_transformer_config):\n        from siirl.utils.megatron.megatron_utils import McoreModuleWrapperConfig, make_megatron_module\n\n        self._init_hf_config_and_tf_config(\n            model_path,\n            tokenizer,\n            self.dtype,\n            override_model_config,\n            override_transformer_config,\n            self.config.model.trust_remote_code,\n            
self.config.megatron.use_mbridge,\n        )\n\n        wrap_config = McoreModuleWrapperConfig(\n            is_value_model=True,  # reward model is value model\n            share_embeddings_and_output_weights=False,\n            wrap_with_ddp=False,\n            use_distributed_optimizer=self.config.megatron.use_distributed_optimizer,\n        )\n        reward_model = make_megatron_module(\n            wrap_config=wrap_config,\n            tf_config=self.tf_config,\n            hf_config=self.hf_config,\n            bridge=self.bridge,\n            override_model_config=override_model_config,\n        )\n\n        if self.config.load_weight:\n            if self.config.megatron.use_dist_checkpointing:\n                load_mcore_dist_weights(reward_model, self.config.megatron.dist_checkpointing_path, is_value_model=True)\n            else:\n                if self.bridge is not None:\n                    local_model_path = get_hf_model_path(self.config)\n                    self.bridge.load_weights(reward_model, local_model_path)\n                else:\n                    load_megatron_gptmodel_weights(\n                        self.config, self.hf_config, reward_model, params_dtype=self.dtype, is_value_model=True\n                    )\n\n        get_torch_device().empty_cache()\n        return reward_model, self.hf_config\n\n    def init_model(self):\n        # create critic\n        import_external_libs(self.config.model.external_lib)\n        override_model_config = self.config.model.override_config\n        override_transformer_config = self.config.model.override_transformer_config\n\n        if not override_transformer_config:\n            override_transformer_config = OmegaConf.create()\n\n        use_shm = self.config.model.use_shm\n        sft_tokenizer_local_path = copy_to_local(self.config.model.input_tokenizer, use_shm=use_shm)\n        sft_tokenizer = load_tokenizer(path=sft_tokenizer_local_path)[\"tokenizer\"]\n        rm_tokenizer_path = self.config.model.rm_tokenizer\n        rm_tokenizer = None\n        if rm_tokenizer_path is not None:\n            rm_tokenizer_local_path = copy_to_local(rm_tokenizer_path, use_shm=use_shm)\n            rm_tokenizer = load_tokenizer(path=rm_tokenizer_local_path)[\"tokenizer\"]\n\n        self.param_dtype = torch.bfloat16\n        self.dtype = PrecisionType.to_dtype(self.param_dtype)\n\n        reward_model_module, reward_model_config = self._build_rm_model(\n            model_path=self.config.model.path,\n            tokenizer=rm_tokenizer,\n            override_model_config=override_model_config,\n            override_transformer_config=override_transformer_config,\n        )\n        # FIXME(sgm): reward model param offload is implemented in MegatronRewardModel\n        # should be implemented in workers\n        self.rm = MegatronRewardModel(\n            config=self.config,\n            reward_model_module=reward_model_module,\n            model_config=reward_model_config,\n            hf_config=self.hf_config,\n            tf_config=self.tf_config,\n            sft_tokenizer=sft_tokenizer,\n            rm_tokenizer=rm_tokenizer,\n        )\n\n    # TODO: reward model use itself tokenizer instead of sft tokenizer\n    # the input_ids, responses, attention_mask and position_ids may be different!\n    def compute_rm_score(self, data: TensorDict):\n        data.meta_info[\"micro_batch_size\"] = self.config.micro_batch_size_per_gpu\n        data.meta_info[\"max_token_len\"] = self.config.forward_max_token_len_per_gpu\n        
data.meta_info[\"use_dynamic_bsz\"] = self.config.use_dynamic_bsz\n        data = data.to(get_device_id())\n        output = self.rm.compute_reward(data)\n        output = output.to(\"cpu\")\n        return output\n\n\n# ================================= Separated Workers =================================\n\nIS_ACTOR_ROLLOUT_REF_INITIALIZED = False\n\ndef global_initialize_model_parallel(config: ActorRolloutRefArguments):\n    # For separated workers, we use actor's megatron config for distributed model initialization\n    megatron_config = config.actor.megatron\n    \n    rank = int(os.environ[\"LOCAL_RANK\"])\n    if not torch.distributed.is_initialized():\n        torch.distributed.init_process_group(\n            backend=get_nccl_backend(),\n            timeout=datetime.timedelta(seconds=600),\n            init_method=os.environ.get(\"DIST_INIT_METHOD\", None),\n        )\n    get_torch_device().set_device(rank)\n\n    global IS_ACTOR_ROLLOUT_REF_INITIALIZED\n    if IS_ACTOR_ROLLOUT_REF_INITIALIZED:\n        return\n\n    if megatron_config.sequence_parallel:\n        os.environ[\"CUDA_DEVICE_MAX_CONNECTIONS\"] = \"1\"\n    \n    mpu.initialize_model_parallel(\n            tensor_model_parallel_size=megatron_config.tensor_model_parallel_size,\n            pipeline_model_parallel_size=megatron_config.pipeline_model_parallel_size,\n            virtual_pipeline_model_parallel_size=megatron_config.virtual_pipeline_model_parallel_size,\n            pipeline_model_parallel_split_rank=None,\n            use_sharp=False,\n            context_parallel_size=megatron_config.context_parallel_size,\n            expert_model_parallel_size=megatron_config.expert_model_parallel_size,\n            expert_tensor_parallel_size=megatron_config.expert_tensor_parallel_size,\n            nccl_communicator_config_path=None,\n        )    \n    set_random_seed(seed=megatron_config.seed)\n    \n    IS_ACTOR_ROLLOUT_REF_INITIALIZED = True\n\n\nIS_MINDSPEED_REPATCH = False\n\ndef global_mindspeed_repatch(config):\n    \"\"\"\n    Use for Mindspeed repatch global once\n    \"\"\"\n\n    global IS_MINDSPEED_REPATCH\n    if repatch is not None and not IS_MINDSPEED_REPATCH:\n        # NPU MindSpeed patch, will be refactored with MindSpeedEngine.\n        repatch(config)\n        IS_MINDSPEED_REPATCH = True\n\n\nclass ActorWorker(MegatronWorker):\n    \"\"\"\n    Dedicated worker for actor training\n    \"\"\"\n\n    def __init__(self, config: DictConfig, process_group=None):\n        # For backward compatibility, we do not seperate the hybrid configurations,\n        # i.e., the `config` here is still ActorRolloutRefArguments\n        assert isinstance(config, ActorRolloutRefArguments), \"config of ActorWorker must be ActorRolloutRefArguments\"\n        super().__init__()\n        self.config = config\n        global_mindspeed_repatch(self.config.actor.megatron.to_dict().get(\"override_transformer_config\", {}))\n        global_initialize_model_parallel(self.config)\n\n        self.config.actor.ppo_mini_batch_size *= self.config.rollout.n\n        self.config.actor.ppo_mini_batch_size //= mpu.get_data_parallel_world_size()\n        if self.config.actor.ppo_micro_batch_size:\n            self.config.actor.ppo_micro_batch_size //= mpu.get_data_parallel_world_size()\n            self.config.actor.ppo_micro_batch_size_per_gpu = self.config.actor.ppo_micro_batch_size\n\n        self._is_offload_param = self.config.actor.megatron.param_offload\n        self._is_offload_grad = self.config.actor.megatron.grad_offload\n       
 self._is_offload_optimizer = self.config.actor.megatron.optimizer_offload\n\n    def _build_actor_model_optimizer(self, model_path, optim_config, override_model_config, override_transformer_config, override_ddp_config):\n        from siirl.utils.megatron.megatron_utils import init_megatron_optim_config\n        from siirl.utils.model_utils.model import print_model_size\n        from siirl.utils.megatron.megatron_utils import McoreModuleWrapperConfig, make_megatron_module\n        from siirl.utils.megatron.optimizer import get_megatron_optimizer, get_megatron_optimizer_param_scheduler\n\n        self._init_hf_config_and_tf_config(\n            model_path,\n            model_path,\n            self.dtype,\n            override_model_config,\n            override_transformer_config,\n            self.config.model.trust_remote_code,\n            self.config.actor.megatron.use_mbridge,\n        )\n        wrap_config = McoreModuleWrapperConfig(\n            is_value_model=False,  # actor is not value model\n            share_embeddings_and_output_weights=self.share_embeddings_and_output_weights,\n            wrap_with_ddp=True,\n            use_distributed_optimizer=self.config.actor.megatron.use_distributed_optimizer,\n        )\n        actor_module = make_megatron_module(\n            wrap_config=wrap_config,\n            tf_config=self.tf_config,\n            hf_config=self.hf_config,\n            bridge=self.bridge,\n            override_model_config=override_model_config,\n            override_ddp_config=override_ddp_config,\n        )\n        print(f\"actor_module: {len(actor_module)}\")\n        if self.config.actor.load_weight:\n            if self.config.actor.megatron.use_dist_checkpointing:\n                load_mcore_dist_weights(\n                    actor_module, self.config.actor.megatron.dist_checkpointing_path, is_value_model=False\n                )\n            else:\n                if self.bridge is not None:\n                    local_model_path = get_hf_model_path(self.config)\n                    self.bridge.load_weights(actor_module, local_model_path)\n                else:\n                    load_megatron_gptmodel_weights(\n                        self.config, self.hf_config, actor_module, params_dtype=self.dtype, is_value_model=False\n                    )\n\n        if self.rank == 0:\n            print_model_size(actor_module[0])\n        log_gpu_memory_usage(\"After MegatronPPOActor init\", logger=logger)\n\n        # TODO: add more optimizer args into config\n        optim_megatron_config = init_megatron_optim_config(optim_config)\n        actor_optimizer = get_megatron_optimizer(model=actor_module, config=optim_megatron_config)\n        actor_optimizer_scheduler = get_megatron_optimizer_param_scheduler(\n                optimizer=actor_optimizer, config=optim_config\n            )\n\n        log_gpu_memory_usage(\"After actor optimizer init\", logger=logger)\n\n        return actor_module, actor_optimizer, actor_optimizer_scheduler, self.hf_config, optim_config\n\n    def init_model(self):\n        import_external_libs(self.config.model.external_lib)\n\n        override_model_config = self.config.model.override_config\n        override_transformer_config = self.config.actor.megatron.override_transformer_config\n        \n        if not override_transformer_config:\n            override_transformer_config = OmegaConf.create()\n        \n        override_ddp_config = self.config.actor.megatron.override_ddp_config\n        if not override_ddp_config:\n          
  override_ddp_config = OmegaConf.create()\n        \n        self.param_dtype = torch.bfloat16\n        log_gpu_memory_usage(\"Before init actor model and optimizer\", logger=logger)\n\n        self.dtype = PrecisionType.to_dtype(self.param_dtype)\n\n        # we need the model for actor\n        optim_config = self.config.actor.optim\n        self.actor_module, self.actor_optimizer, self.actor_optimizer_scheduler, self.actor_model_config, self.actor_optim_config = self._build_actor_model_optimizer(\n            model_path=self.config.model.path,\n            optim_config=optim_config,\n            override_model_config=override_model_config,\n            override_transformer_config=override_transformer_config,\n            override_ddp_config=override_ddp_config,\n        )\n        if self._is_offload_param:\n            offload_megatron_model_to_cpu(self.actor_module)\n            log_gpu_memory_usage(\"After offload actor params and grad during init\", logger=logger)\n        if self._is_offload_optimizer:\n            offload_megatron_optimizer(self.actor_optimizer)\n            log_gpu_memory_usage(\"After offload actor optimizer during init\", logger=logger)\n\n        self.actor = MegatronPPOActor(\n            config=self.config.actor,\n            model_config=self.actor_model_config,\n            hf_config=self.hf_config,\n            tf_config=self.tf_config,\n            actor_module=self.actor_module,\n            actor_optimizer=self.actor_optimizer,\n        )\n        log_gpu_memory_usage(\"After MegatronPPOActor init\", logger=logger)\n\n        self.flops_counter = FlopsCounter(self.actor_model_config)\n        self.checkpoint_mananager = MegatronCheckpointManager(\n            config=self.config,\n            checkpoint_config=self.config.actor.checkpoint,\n            model_config=self.actor_model_config,\n            transformer_config=self.tf_config,\n            role=\"actor\",\n            model=self.actor_module,\n            arch=self.architectures[0],\n            hf_config=self.hf_config,\n            param_dtype=self.param_dtype,\n            share_embeddings_and_output_weights=self.share_embeddings_and_output_weights,\n            processing_class=self.processor if self.processor is not None else self.tokenizer,\n            optimizer=self.actor_optimizer,\n            optimizer_scheduler=self.actor_optimizer_scheduler,\n            use_distributed_optimizer=self.config.actor.megatron.use_distributed_optimizer,\n            use_checkpoint_opt_param_scheduler=self.config.actor.optim.use_checkpoint_opt_param_scheduler,\n            bridge=self.bridge,\n            use_dist_checkpointing=self.config.actor.megatron.use_dist_checkpointing,\n        )\n        get_torch_device().empty_cache()\n        log_gpu_memory_usage(\"After init_model finish\", logger=logger)\n\n    @GPUMemoryLogger(role=\"update_actor\", logger=logger)\n    def update_actor(self, data: TensorDict):\n        if self._is_offload_param:\n            load_megatron_model_to_gpu(self.actor_module)\n            log_gpu_memory_usage(\"After load actor params and grad during update_actor\", logger=logger)\n        if self._is_offload_optimizer:\n            load_megatron_optimizer(self.actor_optimizer)\n            log_gpu_memory_usage(\"After load actor optimizer during update_actor\", logger=logger)\n        data = data.to(get_device_name())\n\n        micro_batch_size = self.config.actor.ppo_micro_batch_size_per_gpu\n        data[\"micro_batch_size\"] = NonTensorData(micro_batch_size)\n        
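# Time the policy update so the elapsed time can be combined with the global token count to report actor MFU below.\n        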
with Timer(name=\"update_policy\", logger=None) as timer:\n            metrics = self.actor.update_policy(data=data)\n        delta_time = timer.last\n        global_num_tokens = data[\"global_token_num\"]\n        estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)\n        metrics[\"perf/mfu/actor\"] = estimated_flops / promised_flops\n        metrics[\"perf/delta_time/actor\"] = delta_time\n        metrics[\"perf/max_memory_allocated_gb\"] = get_torch_device().max_memory_allocated() / (1024**3)\n        metrics[\"perf/max_memory_reserved_gb\"] = get_torch_device().max_memory_reserved() / (1024**3)\n        metrics[\"perf/cpu_memory_used_gb\"] = psutil.virtual_memory().used / (1024**3)\n\n        # TODO: here, we should return all metrics\n        data[\"metrics\"] = NonTensorData(metrics)\n        data = data.to(\"cpu\")\n\n        if self._is_offload_param:\n            offload_megatron_model_to_cpu(self.actor_module)\n            log_gpu_memory_usage(\"After offload actor params and grad during update_actor\", logger=logger)\n        if self._is_offload_optimizer:\n            offload_megatron_optimizer(self.actor_optimizer)\n            log_gpu_memory_usage(\"After offload actor optimizer during update_actor\", logger=logger)\n\n        return data\n\n    @GPUMemoryLogger(role=\"compute_log_prob\", logger=logger)\n    def compute_log_prob(self, data: TensorDict):\n        torch.cuda.synchronize()\n        if self._is_offload_param:\n            load_megatron_model_to_gpu(self.actor_module, load_grad=False)\n            log_gpu_memory_usage(\"After load actor params during compute_log_prob\", logger=logger)\n\n        # we should always recompute old_log_probs when it is HybridEngine\n        data[\"micro_batch_size\"] = NonTensorData(self.config.rollout.log_prob_micro_batch_size_per_gpu)\n        data[\"max_token_len\"] = NonTensorData(self.config.rollout.log_prob_max_token_len_per_gpu)\n        data[\"use_dynamic_bsz\"] = NonTensorData(self.config.rollout.log_prob_use_dynamic_bsz)\n        data[\"temperature\"] = NonTensorData(self.config.rollout.temperature)\n        data = data.to(get_device_id())\n        with Timer(name=\"compute_log_prob\", logger=None) as timer:\n            output, entropys = self.actor.compute_log_prob(data=data, calculate_entropy=True)\n        delta_time = timer.last\n\n        # store results of Actor old_log_probs\n        data[\"old_log_probs\"] = output\n        data[\"entropys\"] = entropys\n\n        # update metrics\n        metrics = {}\n        global_num_tokens = data[\"global_token_num\"]\n        estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)\n        metrics[\"perf/mfu/actor_log_prob\"] = estimated_flops / promised_flops\n        metrics[\"perf/delta_time/actor_log_prob\"] = delta_time\n        data[\"metrics\"] = NonTensorData(metrics)\n        data = data.to(\"cpu\")\n        # clear kv cache\n        if self._is_offload_param:\n            offload_megatron_model_to_cpu(self.actor_module)\n            log_gpu_memory_usage(\"After offload actor params and grad during compute_log_prob\", logger=logger)\n        return data\n\n    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):\n        if self._is_offload_param:\n            load_megatron_model_to_gpu(self.actor_module)\n        self.checkpoint_mananager.load_checkpoint(local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load)\n       
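 # After loading the checkpoint, re-offload parameters and optimizer state to CPU when offloading is enabled.\n       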
 if self._is_offload_param:\n            offload_megatron_model_to_cpu(self.actor_module)\n        if self._is_offload_optimizer:\n            offload_megatron_optimizer(self.actor_optimizer)\n\n    def load_pretrained_model(self, checkpoint_path, del_local_after_load=True):\n        pass\n\n    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):\n        if self._is_offload_param:\n            load_megatron_model_to_gpu(self.actor_module)\n        self.checkpoint_mananager.save_checkpoint(local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep)\n        torch.distributed.barrier()\n        if self._is_offload_param:\n            offload_megatron_model_to_cpu(self.actor_module)\n\n\nclass RolloutWorker(MegatronWorker):\n    \"\"\"\n    Dedicated worker for rollout inference\n    \"\"\"\n\n    def __init__(self, config: DictConfig, process_group=None):\n        # For backward compatibility, we do not seperate the hybrid configurations,\n        # i.e., the `config` here is still ActorRolloutRefArguments\n        assert isinstance(config, ActorRolloutRefArguments), \"config of RolloutWorker must be ActorRolloutRefArguments\"\n        super().__init__()\n        self.config = config\n        global_mindspeed_repatch(self.config.actor.megatron.to_dict().get(\"override_transformer_config\", {}))\n\n        # normalize rollout config\n        global_initialize_model_parallel(self.config)\n\n        if self.config.rollout.log_prob_micro_batch_size:\n            self.config.rollout.log_prob_micro_batch_size //= mpu.get_data_parallel_world_size()\n            self.config.rollout.log_prob_micro_batch_size_per_gpu = self.config.rollout.log_prob_micro_batch_size\n        \n        self.device_mesh = None\n\n    def _build_rollout(self, trust_remote_code=False):\n        from torch.distributed.device_mesh import init_device_mesh\n        infer_tp = self.config.rollout.tensor_model_parallel_size\n        dp = self.world_size // infer_tp\n        assert self.world_size % infer_tp == 0, f\"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}\"\n\n        if self.config.rollout.name == \"vllm\":\n            from siirl.engine.rollout.vllm_rollout import vLLMRollout\n            # NOTE(sgm): If the QKV and gate_up projection layer are concate together in actor,\n            # we will reorganize their weight format when resharding from actor to rollout.\n            rollout_device_mesh = init_device_mesh(get_device_name(), mesh_shape=(dp, infer_tp), mesh_dim_names=[\"dp\", \"infer_tp\"])\n            self.device_mesh = rollout_device_mesh\n            log_gpu_memory_usage(\"Before building vllm rollout\", logger=None)\n            # local_path = copy_to_local(self.config.model.path, use_shm=self.config.model.use_shm)\n            rollout = vLLMRollout(\n                model_path=self.local_path,\n                config=self.config.rollout,\n                tokenizer=self.tokenizer,\n                model_hf_config=self.hf_config,\n                device_mesh=rollout_device_mesh,\n                trust_remote_code=trust_remote_code,\n            )\n            log_gpu_memory_usage(\"After building vllm rollout\", logger=logger)\n\n        elif self.config.rollout.name in [\"sglang\", \"sglang_async\"]:\n            from siirl.engine.rollout.sglang_rollout import SGLangRollout\n            if self.config.rollout.name == \"sglang_async\":\n                warnings.warn(\n                    
\"'sglang_async' has been deprecated and merged into 'sglang'. Please use 'sglang' going forward.\",\n                    DeprecationWarning,\n                    stacklevel=2,\n                )\n            rollout_device_mesh = init_device_mesh(\"cpu\", mesh_shape=(dp, infer_tp, 1), mesh_dim_names=(\"dp\", \"tp\", \"pp\"))\n            self.device_mesh = rollout_device_mesh\n            # local_path = copy_to_local(self.config.model.path)\n            log_gpu_memory_usage(f\"Before building {self.config.rollout.name} rollout\", logger=None)\n            rollout = SGLangRollout(\n                actor_module=self.local_path,\n                config=self.config.rollout,\n                tokenizer=self.tokenizer,\n                model_hf_config=self.hf_config,\n                trust_remote_code=trust_remote_code,\n                processing_class=self.processor if self.processor is not None else self.tokenizer,\n                device_mesh=rollout_device_mesh,\n            )\n            log_gpu_memory_usage(f\"After building {self.config.rollout.name} rollout\", logger=None)\n        else:\n            raise NotImplementedError(\"Only vllmRollout and SGLangRollout are supported with Megatron now\")\n        \n        return rollout, None\n\n    def init_model(self):\n        import_external_libs(self.config.model.external_lib)\n        self.param_dtype = torch.bfloat16\n        log_gpu_memory_usage(\"Before init rollout inference engine\", logger=logger)\n\n        self.dtype = PrecisionType.to_dtype(self.param_dtype)\n\n        # Initialize HF config and tokenizer for inference engine setup\n        from siirl.utils.model_utils.model import get_generation_config\n        \n        # self.local_path = copy_to_local(self.config.model.path, use_shm=self.config.model.use_shm)\n        # self.tokenizer = load_tokenizer(model_args=self.config.model)['tokenizer']\n        model_path = self.config.model.path\n        model_override_config = self.config.model.override_config\n        model_override_transformer_config = self.config.actor.megatron.override_transformer_config\n        model_trust_remote_code = self.config.model.trust_remote_code\n        \n        self._init_hf_config_and_tf_config(\n            model_path, model_path, \n            self.dtype, \n            model_override_config, \n            model_override_transformer_config, \n            model_trust_remote_code, \n            False, # mbridge is not used for rollout\n        )\n        \n        self.generation_config = get_generation_config(self.local_path)\n\n        # Only build the inference engine (vLLM/SGLang) - no need for Megatron model\n        self.rollout, self.sharding_manager = self._build_rollout(trust_remote_code=self.config.model.trust_remote_code)\n        get_torch_device().empty_cache()\n        log_gpu_memory_usage(\"After rollout init\", logger=logger)\n\n    @GPUMemoryLogger(role=\"generate_sequences\", logger=logger)\n    def generate_sequences(self, prompts: TensorDict):\n        prompts = prompts.to(get_device_id())\n        prompts[\"eos_token_id\"] = NonTensorData(self.generation_config.eos_token_id if self.generation_config is not None else self.tokenizer.eos_token_id)\n        prompts[\"pad_token_id\"] = NonTensorData(self.generation_config.pad_token_id if self.generation_config is not None else self.tokenizer.pad_token_id)\n\n        with self.sharding_manager:\n            log_gpu_memory_usage(\"After entering sharding manager\", logger=logger)\n            with Timer(name=\"generate_sequences\", 
logger=None) as timer:\n                output = self.rollout.generate_sequences(prompts=prompts)\n            delta_time = timer.last\n            # Note: Add metrics for Rollout, we may use them later.\n            metrics = {}\n            metrics[\"perf/delta_time/rollout\"] = delta_time\n        log_gpu_memory_usage(\"After rollout generation\", logger=logger)\n        output[\"metrics\"] = NonTensorData(metrics, batch_size=None)\n        output = output.to(\"cpu\")\n        # clear kv cache\n        get_torch_device().empty_cache()\n        return output\n    \n    def set_rollout_sharding_manager(self, sharding_manager):\n        self.sharding_manager = sharding_manager\n\n\nclass ReferenceWorker(MegatronWorker):\n    \"\"\"\n    Dedicated worker for reference policy\n    \"\"\"\n\n    def __init__(self, config: DictConfig, process_group=None):\n        # For backward compatibility, we do not seperate the hybrid configurations,\n        # i.e., the `config` here is still ActorRolloutRefArguments\n        assert isinstance(config, ActorRolloutRefArguments), \"config must be ActorRolloutRefArguments\"\n        super().__init__()\n        self.config = config\n        global_mindspeed_repatch(self.config.actor.megatron.to_dict().get(\"override_transformer_config\", {}))\n        global_initialize_model_parallel(self.config)\n\n        # normalize ref config\n        if self.config.ref.log_prob_micro_batch_size:\n            self.config.ref.log_prob_micro_batch_size //= mpu.get_data_parallel_world_size()\n            self.config.ref.log_prob_micro_batch_size_per_gpu = self.config.ref.log_prob_micro_batch_size\n        else:\n            assert self.config.ref.log_prob_micro_batch_size_per_gpu is not None, \"Please note that in the ref policy configuration, `log_prob_micro_batch_size_per_gpu` and `log_prob_micro_batch_size` should not be None at the same time.\"\n        \n        self._ref_is_offload_param = self.config.ref.megatron.param_offload\n\n    def _build_ref_model(self, model_path, override_model_config, override_transformer_config):\n        from siirl.utils.megatron.megatron_utils import McoreModuleWrapperConfig, make_megatron_module\n        self._init_hf_config_and_tf_config(\n            model_path,\n            model_path,\n            self.dtype,\n            override_model_config,\n            override_transformer_config,\n            self.config.model.trust_remote_code,\n            self.config.actor.megatron.use_mbridge,\n        )\n        wrap_config = McoreModuleWrapperConfig(\n            is_value_model=False,  # ref is not value model\n            share_embeddings_and_output_weights=self.share_embeddings_and_output_weights,\n            wrap_with_ddp=False,\n            use_distributed_optimizer=self.config.ref.megatron.use_distributed_optimizer,\n        )\n        ref_module = make_megatron_module(\n            wrap_config=wrap_config,\n            tf_config=self.tf_config,\n            hf_config=self.hf_config,\n            bridge=self.bridge,\n            override_model_config=override_model_config,\n        )\n        if self.config.ref.load_weight:  # should align with the actor:\n            assert self.config.actor.load_weight == self.config.ref.load_weight\n            print(\"load ref weight start\")\n            if self.config.ref.megatron.use_dist_checkpointing:\n                load_mcore_dist_weights(\n                    ref_module, self.config.ref.megatron.dist_checkpointing_path, is_value_model=False\n                )\n            else:\n       
         if self.bridge is not None:\n                    local_model_path = get_hf_model_path(self.config)\n                    self.bridge.load_weights(ref_module, local_model_path)\n                else:\n                    load_megatron_gptmodel_weights(\n                        self.config, self.hf_config, ref_module, params_dtype=self.dtype, is_value_model=False\n                    )\n        log_gpu_memory_usage(\"After ref module init\", logger=logger)\n        return ref_module, self.hf_config\n\n    def init_model(self):\n        import_external_libs(self.config.model.external_lib)\n\n        override_model_config = self.config.model.override_config\n        override_transformer_config = self.config.ref.megatron.override_transformer_config\n        \n        if not override_transformer_config:\n            override_transformer_config = OmegaConf.create()\n        \n        self.param_dtype = torch.bfloat16\n        self.dtype = PrecisionType.to_dtype(self.param_dtype)\n\n        log_gpu_memory_usage(\"Before init ref model\", logger=logger)\n        self.ref_module, self.ref_model_config = self._build_ref_model(\n            model_path=self.config.model.path,\n            override_model_config=override_model_config,\n            override_transformer_config=override_transformer_config,\n        )\n        log_gpu_memory_usage(\"After init ref model\", logger=logger)\n        self.ref_policy = MegatronPPOActor(\n            config=self.config.ref,\n            model_config=self.ref_model_config,\n            hf_config=self.hf_config,\n            tf_config=self.tf_config,\n            actor_module=self.ref_module,\n            actor_optimizer=None,\n        )\n        if self._ref_is_offload_param:\n            offload_megatron_model_to_cpu(self.ref_module)\n            log_gpu_memory_usage(\"After offload ref params during init\", logger=logger)\n\n        self.flops_counter = FlopsCounter(self.ref_model_config)\n        get_torch_device().empty_cache()\n        log_gpu_memory_usage(\"After finish ref model init\", logger=logger)\n\n    @GPUMemoryLogger(role=\"compute_ref_log_prob\", logger=logger)\n    def compute_ref_log_prob(self, data: TensorDict):\n        if self._ref_is_offload_param:\n            load_megatron_model_to_gpu(self.ref_module, load_grad=False)\n            log_gpu_memory_usage(\"After load ref params and grad during compute_ref_log_prob\", logger=logger)\n        micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu\n        data[\"micro_batch_size\"] = NonTensorData(micro_batch_size)\n        data[\"max_token_len\"] = NonTensorData(self.config.ref.log_prob_max_token_len_per_gpu)\n        data[\"use_dynamic_bsz\"] = NonTensorData(self.config.ref.log_prob_use_dynamic_bsz)\n        data[\"temperature\"] = NonTensorData(self.config.rollout.temperature)\n        data = data.to(get_device_id())\n\n        with Timer(name=\"compute_ref_log_prob\", logger=None) as timer:\n            output, _ = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False)\n        delta_time = timer.last\n\n        # update metrics\n        metrics = {}\n        global_num_tokens = data[\"global_token_num\"]\n        estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)\n        metrics[\"perf/mfu/ref\"] = estimated_flops / promised_flops\n        metrics[\"perf/delta_time/ref\"] = delta_time\n        data[\"metrics\"] = NonTensorData(metrics)\n        data[\"ref_log_prob\"] = output\n        data = 
data.to(\"cpu\")\n        if self._ref_is_offload_param:\n            offload_megatron_model_to_cpu(self.ref_module)\n            log_gpu_memory_usage(\"After offload ref params and grad during compute_ref_log_prob\", logger=logger)\n        return data\n\n\nclass AsyncRolloutWorker(RolloutWorker):\n    def _build_rollout(self, trust_remote_code=False):\n        rollout, rollout_sharding_manager = super()._build_rollout(trust_remote_code)\n\n        # NOTE: rollout is not actually initialized here, it's deferred\n        # to be initialized by AsyncvLLMServer.\n\n        self.vllm_tp_size = self.config.rollout.tensor_model_parallel_size\n        self.vllm_dp_rank = int(os.environ[\"RANK\"]) // self.vllm_tp_size\n        self.vllm_tp_rank = int(os.environ[\"RANK\"]) % self.vllm_tp_size\n\n        # used for sleep/wake_up\n        rollout.sharding_manager = rollout_sharding_manager\n\n        return rollout, rollout_sharding_manager\n\n    def execute_method(self, method: Union[str, bytes], *args, **kwargs):\n        \"\"\"Called by ExternalRayDistributedExecutor collective_rpc.\"\"\"\n        if self.vllm_tp_rank == 0 and method != \"execute_model\":\n            print(f\"[DP={self.vllm_dp_rank},TP={self.vllm_tp_rank}] execute_method: {method if isinstance(method, str) else 'Callable'}\")\n        return self.rollout.execute_method(method, *args, **kwargs)\n\n    async def chat_completion(self, json_request):\n        ret = await self.rollout.chat_completion(json_request)\n        return ret\n\n    async def wake_up(self):\n        await self.rollout.wake_up()\n        # return something to block the caller\n        return True\n\n    async def sleep(self):\n        await self.rollout.sleep()\n        # return something to block the caller\n        return True\n\n    def set_rollout_sharding_manager(self, sharding_manager):\n        super().set_rollout_sharding_manager(sharding_manager)\n        self.rollout.sharding_manager = sharding_manager\n"
  },
  {
    "path": "siirl/engine/reward_manager/__init__.py",
    "content": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .dapo import DAPORewardManager\nfrom .naive import NaiveRewardManager\nfrom .embodied import EmbodiedRewardManager\nfrom .parallel import ParallelRewardManager\n# Lazy import for embodied reward manager to avoid loading embodied dependencies for LLM/VLM tasks\ndef __getattr__(name):\n    if name == \"EmbodiedRewardManager\":\n        from .embodied import EmbodiedRewardManager\n        return EmbodiedRewardManager\n    raise AttributeError(f\"module '{__name__}' has no attribute '{name}'\")\n\n__all__ = [\"DAPORewardManager\", \"NaiveRewardManager\", \"EmbodiedRewardManager\",\"ParallelRewardManager\"]\n"
  },
  {
    "path": "siirl/engine/reward_manager/dapo.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom collections import defaultdict\n\nimport os\nimport torch\nfrom tensordict import TensorDict\n\nfrom loguru import logger\nfrom siirl.utils.reward_score import default_compute_score\n\n\nclass DAPORewardManager:\n    \"\"\"The reward manager.\"\"\"\n\n    def __init__(\n        self,\n        tokenizer,\n        num_examine,\n        compute_score=None,\n        reward_fn_key=\"data_source\",\n        max_resp_len=None,\n        overlong_buffer_cfg=None,\n    ) -> None:\n        self.tokenizer = tokenizer\n        self.num_examine = num_examine  # the number of batches of decoded responses to print to the console\n        self.compute_score = compute_score or default_compute_score\n        self.reward_fn_key = reward_fn_key\n        self.overlong_buffer_cfg = overlong_buffer_cfg\n        self.max_resp_len = max_resp_len\n        self.rank = int(os.environ.get(\"RANK\", \"0\"))\n\n        if self.overlong_buffer_cfg is not None:\n            assert self.max_resp_len is not None, f\"max_resp_len must be provided if {overlong_buffer_cfg=}, but got None\"\n\n    def __call__(self, data: TensorDict, return_dict: bool = False):\n        \"\"\"We will expand this function gradually based on the available datasets\"\"\"\n\n        # If there is rm score, we directly return rm score. 
Otherwise, we compute via rm_score_fn\n        if \"rm_scores\" in data.keys():\n            if return_dict:\n                return {\"reward_tensor\": data[\"rm_scores\"]}\n            else:\n                return data[\"rm_scores\"]\n\n        reward_tensor = torch.zeros_like(data[\"responses\"], dtype=torch.float32)\n        reward_extra_info = defaultdict(list)\n\n        already_print_data_sources = {}\n\n        for i in range(len(data)):\n            data_item = data[i]  # TensorDictItem\n\n            prompt_ids = data_item[\"prompts\"]\n\n            prompt_length = prompt_ids.shape[-1]\n\n            valid_prompt_length = data_item[\"attention_mask\"][:prompt_length].sum()\n            valid_prompt_ids = prompt_ids[-valid_prompt_length:]\n\n            response_ids = data_item[\"responses\"]\n            valid_response_length = data_item[\"attention_mask\"][prompt_length:].sum()\n            valid_response_ids = response_ids[:valid_response_length]\n\n            # decode\n            prompt_str = self.tokenizer.decode(valid_prompt_ids, skip_special_tokens=True)\n            response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True)\n            eos_token = self.tokenizer.eos_token\n            if response_str.endswith(eos_token):\n                response_str = response_str[: -len(eos_token)]\n\n            ground_truth = data[\"reward_model\"][i][\"ground_truth\"]\n\n            data_source = data[self.reward_fn_key][i]\n\n            extra_info =  data[\"extra_info\"][i] if \"extra_info\" in data else None\n\n            result = self.compute_score(\n                data_source=data_source,\n                solution_str=response_str,\n                ground_truth=ground_truth,\n                extra_info=extra_info,\n            )\n\n            score: float\n            if isinstance(result, dict):\n                score = result[\"score\"]\n                # Store the information including original reward\n                for key, value in result.items():\n                    reward_extra_info[key].append(value)\n            else:\n                score = result\n\n            reward = score\n\n            if self.overlong_buffer_cfg.enable:\n                overlong_buffer_len = self.overlong_buffer_cfg.len\n                expected_len = self.max_resp_len - overlong_buffer_len\n                exceed_len = valid_response_length - expected_len\n                overlong_penalty_factor = self.overlong_buffer_cfg.penalty_factor\n                overlong_reward = min(-exceed_len / overlong_buffer_len * overlong_penalty_factor, 0)\n                reward += overlong_reward\n                if self.overlong_buffer_cfg.log:\n                    reward_extra_info[\"overlong_reward\"].append(overlong_reward)\n                    reward_extra_info[\"overlong\"].append(overlong_reward < 0)\n\n            reward_tensor[i, valid_response_length - 1] = reward\n\n            if data_source not in already_print_data_sources:\n                already_print_data_sources[data_source] = 0\n\n            if self.rank == 0 and already_print_data_sources[data_source] < self.num_examine:\n                already_print_data_sources[data_source] += 1\n                logger.info(f\"[prompt] {prompt_str}\")\n                logger.info(f\"[response] {response_str}\")\n                logger.info(f\"[ground_truth] {ground_truth}\")\n                if isinstance(result, dict):\n                    for key, value in result.items():\n                        
logger.info(f\"[{key}] {value}\")\n                else:\n                    logger.info(f\"[score] {score}\")\n\n        if return_dict:\n            return {\n                \"reward_tensor\": reward_tensor,\n                \"reward_extra_info\": reward_extra_info,\n            }\n        else:\n            return reward_tensor\n"
  },
  {
    "path": "siirl/engine/reward_manager/embodied.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\nfrom tensordict import TensorDict\nimport numpy as np\nimport torch\nfrom loguru import logger\nfrom transformers import PreTrainedTokenizer\n\n\n\nclass EmbodiedRewardManager:\n    \"\"\"\n    Manages the reward calculation process for Embodied AI tasks.\n\n    This class acts as an orchestrator. It receives the framework-specific\n    `TensorDict` object and delegates the complex reward computation to an\n    injected `compute_score` function.\n    \"\"\"\n\n    def __init__(\n        self,\n        tokenizer: Optional[PreTrainedTokenizer] = None,\n        num_examine: int = 1,\n        compute_score: Optional[Callable] = None,\n        reward_fn_key: str = \"data_source\",\n        **reward_kwargs,\n    ):\n        \"\"\"\n        Initializes the reward manager.\n\n        Args:\n            tokenizer: The tokenizer, if needed for any text processing.\n            num_examine: The number of reward examples to log for debugging.\n            compute_score: The function to call for calculating reward scores.\n                           Defaults to the compute_embodied_reward.\n            reward_fn_key: The key to identify the data source.\n            **reward_kwargs: A dictionary for additional parameters like\n                             `action_token_len` and `reward_coef`.\n        \"\"\"\n        self.tokenizer = tokenizer\n        self.num_examine = num_examine\n\n        # Import default compute_score if not provided\n        if compute_score is None:\n            try:\n                from siirl.utils.reward_score.embodied import compute_embodied_reward\n                self.compute_score = compute_embodied_reward\n            except ImportError:\n                logger.warning(\n                    \"Could not import compute_embodied_reward. 
\"\n                    \"Please provide compute_score function or ensure embodied reward module exists.\"\n                )\n                self.compute_score = None\n        else:\n            self.compute_score = compute_score\n\n        self.reward_fn_key = reward_fn_key\n        self.rank = int(os.environ.get(\"RANK\", \"0\"))\n        self.print_count = 0\n\n        # Extract specific parameters from kwargs with safe defaults.\n        self.action_token_len = reward_kwargs.get(\"action_token_len\", 7)\n        self.reward_coef = reward_kwargs.get(\"reward_coef\", 1.0)\n\n    def __call__(self, data: TensorDict, return_dict: bool = False) -> Union[Dict[str, Any], Tuple[Dict[str, torch.Tensor], Dict[str, float]]]:\n        \"\"\"\n        Calculates and returns the reward tensors and metrics for a given data batch.\n        \n        Args:\n            data: TensorDict containing batch information\n            return_dict: If True, returns format compatible with compute_reward function\n                        If False, returns format compatible direct call\n        \n        Returns:\n            If return_dict=True: {\"reward_tensor\": tensor, \"reward_extra_info\": dict}\n            If return_dict=False: (reward_tensor_dict, reward_metrics) tuple\n        \"\"\"\n        batch_size = data[\"responses\"].shape[0]\n\n        # --- Step 1: Delegate the core reward calculation ---\n        if self.compute_score is None:\n            # Return zero rewards as fallback\n            verifier_scores = [0.0] * batch_size\n            format_scores = [1.0] * batch_size\n            scores_info = [{\"score\": 0.0, \"format_correctness\": 1.0, \"is_success\": False} for _ in range(batch_size)]\n        else:\n            scores_info = self.compute_score(batch_data=data)\n            verifier_scores = [info[\"score\"] for info in scores_info]\n            format_scores = [info.get(\"format_correctness\", 1.0) for info in scores_info]\n\n        # --- Step 3: Log debug examples (on rank 0 only) ---\n        if self.rank == 0 and self.print_count < self.num_examine:\n            logger.info(\"--- EmbodiedRewardManager Reward Calculation Example ---\")\n            for i in range(min(batch_size, 2)):\n                info = scores_info[i]\n                logger.info(f\"Sample {i} | Task: {info.get('task_name', 'N/A')}\")\n                logger.info(f\"  - Success: {info.get('is_success')}\")\n                if not info.get(\"is_success\"):\n                    dist = info.get(\"normalized_distance\", \"N/A\")\n                    if isinstance(dist, float):\n                        logger.info(f\"  - Normalized Distance: {dist:.4f}\")\n                    else:\n                        logger.info(f\"  - Normalized Distance: {dist}\")\n                logger.info(f\"  -> Final Score: {info.get('score', 0.0):.4f}\")\n            self.print_count += 1\n\n        # --- Step 4: Populate the reward tensor at the final timestep ---\n        # The reward is applied as a terminal reward at the end of the action sequence.\n        \n        verifier_rewards = torch.zeros_like(data[\"responses\"], dtype=torch.float32)\n        \n        verifier_rewards = verifier_rewards.view(batch_size, -1)\n\n        valid_response_length = data[\"finish_step\"] * self.action_token_len\n\n        for i in range(batch_size):\n            last_step_idx = valid_response_length[i] - 1\n            if last_step_idx >= 0:\n                verifier_rewards[i, last_step_idx] = verifier_scores[i]\n\n        # --- Step 5: 
Aggregate final rewards and metrics ---\n        reward_tensor_dict = {\"gt_scores\": verifier_rewards}\n        reward_metrics = {}\n\n        final_reward_tensor = torch.zeros_like(verifier_rewards)\n        if self.reward_coef != 0:\n            final_reward_tensor += self.reward_coef * reward_tensor_dict[\"gt_scores\"]\n\n            # Add all relevant metrics to the dictionary for logging.\n            reward_metrics[\"verifier_mean\"] = torch.tensor(verifier_scores).mean().item()\n            reward_metrics[\"format_correctness_mean\"] = torch.tensor(format_scores).mean().item()\n\n        reward_tensor_dict[\"all\"] = final_reward_tensor\n        reward_metrics[\"reward_all\"] = final_reward_tensor.sum(dim=-1).mean().item()\n\n        # Return format based on return_dict flag\n        if return_dict:\n            # Format for compute_reward function (scheduler.reward.compute_reward)\n            # Return per-sample format to match NaiveRewardManager/BatchRewardManager standard\n            reward_extra_info = {\n                \"verifier_score\": verifier_scores,      # Per-sample scores (already a list)\n                \"format_correctness\": format_scores,    # Per-sample format correctness (already a list)\n            }\n            return {\n                \"reward_tensor\": reward_tensor_dict[\"all\"],\n                \"reward_extra_info\": reward_extra_info\n            }\n        else:\n            return reward_tensor_dict, reward_metrics\n\n    def verify(self, data: TensorDict) -> Tuple[List[float], Dict[str, float], Dict[str, float], Dict[str, float]]:\n        \"\"\"\n        Verify and compute rewards for validation.\n        \n        This method directly reads the 'complete' field from data.batch.\n        \n        Args:\n            data: TensorDict containing batch information with embodied task data\n            \n        Returns:\n            tuple: (verifier_scores, reward_metrics, format_metrics, reward_format_metrics)\n                - verifier_scores: List[float] - Binary success (0/1) for each sample\n                - reward_metrics: Dict[str, float] - Aggregated metrics\n                - format_metrics: Dict[str, float] - Format correctness (always 1.0)\n                - reward_format_metrics: Dict[str, float] - Same as reward_metrics\n        \"\"\"\n        # Step 1: Read complete field directly from batch\n        completes = data['complete'].tolist()\n        batch_size = data['responses'].size(0)\n        assert len(completes) == batch_size\n        \n        # Convert boolean to float (0.0 or 1.0)\n        score = [float(item) for item in completes]\n        \n        # Step 2: Store to batch tensors\n        device = data['responses'].device\n        acc_tensor = torch.tensor(score, dtype=torch.float32, device=device)\n        format_tensor = torch.ones(batch_size, dtype=torch.float32, device=device)\n        \n        data['acc'] = acc_tensor\n        data['format_correctness'] = format_tensor\n        \n        # Step 3: Compute aggregated metrics\n        success_rate = acc_tensor.mean().item()\n        \n        reward_metrics = {'all': success_rate}\n        format_metrics = {'all': 1.0}  # Always 1.0, no need to compute\n        reward_format_metrics = {'all': success_rate}\n        \n        # Return the 4-tuple expected by validation_mixin.py\n        return score, reward_metrics, format_metrics, reward_format_metrics\n\n"
  },
  {
    "path": "siirl/engine/reward_manager/naive.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nfrom collections import defaultdict\nfrom tensordict import TensorDict\nimport torch\nfrom loguru import logger\nfrom torch import distributed as dist\n\nfrom siirl.utils.reward_score import default_compute_score\n\n\nclass NaiveRewardManager:\n    \"\"\"The reward manager.\"\"\"\n\n    def __init__(self, tokenizer, num_examine, compute_score=None, reward_fn_key=\"data_source\", **reward_kwargs) -> None:\n        self.tokenizer = tokenizer\n        self.num_examine = num_examine  # the number of batches of decoded responses to print to the console\n        self.compute_score = compute_score or default_compute_score\n        self.reward_fn_key = reward_fn_key\n        self.rank = int(os.environ.get(\"RANK\", \"0\"))\n\n    def __call__(self, data: TensorDict, return_dict=False):\n        \"\"\"We will expand this function gradually based on the available datasets\"\"\"\n\n        # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn\n        if \"rm_scores\" in data.keys():\n            if return_dict:\n                return {\"reward_tensor\": data[\"rm_scores\"]}\n            else:\n                return data[\"rm_scores\"]\n        reward_tensor = torch.zeros_like(data[\"responses\"], dtype=torch.float32)\n        reward_extra_info = defaultdict(list)\n\n        already_print_data_sources = {}\n\n        for i in range(len(data)):\n            data_item = data[i] \n\n            prompt_ids = data_item[\"prompts\"]\n\n            prompt_length = prompt_ids.shape[-1]\n\n            valid_prompt_length = data_item[\"attention_mask\"][:prompt_length].sum()\n            valid_prompt_ids = prompt_ids[-valid_prompt_length:]\n\n            response_ids = data_item[\"responses\"]\n            valid_response_length = data_item[\"attention_mask\"][prompt_length:].sum()\n            valid_response_ids = response_ids[:valid_response_length]\n\n            # decode\n            prompt_str = self.tokenizer.decode(valid_prompt_ids, skip_special_tokens=True)\n            response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True)\n\n            from siirl.models.transformers.internvl import IMG_CONTEXT_TOKEN\n\n            if IMG_CONTEXT_TOKEN in prompt_str:\n                prompt_str = prompt_str.replace(IMG_CONTEXT_TOKEN, \"\")\n            ground_truth = data[\"reward_model\"][i][\"ground_truth\"]\n\n            data_source = data[self.reward_fn_key][i]\n\n            extra_info =  data[\"extra_info\"][i] if \"extra_info\" in data else None\n            \n            score = self.compute_score(\n                data_source=data_source,\n                solution_str=response_str,\n                ground_truth=ground_truth,\n                extra_info=extra_info,\n            )\n\n            if isinstance(score, dict):\n                reward = score[\"score\"]\n                # Store the information 
including original reward\n                for key, value in score.items():\n                    reward_extra_info[key].append(value)\n            else:\n                reward = score\n\n            reward_tensor[i, valid_response_length - 1] = reward\n\n            if data_source not in already_print_data_sources:\n                already_print_data_sources[data_source] = 0\n\n            if self.rank == 0 and already_print_data_sources[data_source] < self.num_examine:\n                already_print_data_sources[data_source] += 1\n                logger.info(f\"rank:{dist.get_rank()}\")\n                logger.info(f\"[prompt] {prompt_str}\")\n                logger.info(f\"[response] {response_str}\")\n                logger.info(f\"[ground_truth] {ground_truth}\")\n                if isinstance(score, dict):\n                    for key, value in score.items():\n                        logger.info(f\"[{key}] {value}\")\n                else:\n                    logger.info(f\"[score] {score}\")\n\n        if return_dict:\n            return {\n                \"reward_tensor\": reward_tensor,\n                \"reward_extra_info\": reward_extra_info,\n            }\n        else:\n            return reward_tensor\n"
  },
  {
    "path": "siirl/engine/reward_manager/parallel.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom siirl.utils.reward_score import _default_compute_score\nfrom siirl.models.transformers.internvl import IMG_CONTEXT_TOKEN\n\nimport torch\nimport os\nimport multiprocessing as mp\nfrom functools import partial\nimport multiprocessing.dummy as mp_dummy\nfrom functools import partial\nfrom tensordict import TensorDict\n\nclass ParallelRewardManager:\n    \"\"\"The reward manager.\"\"\"\n\n    def __init__(self, tokenizer, num_examine, compute_score=None, reward_fn_key=None) -> None:\n        self.tokenizer = tokenizer\n        self.num_examine = num_examine  # the number of batches of decoded responses to print to the console\n        self.compute_score = compute_score or _default_compute_score\n        self.rank = int(os.environ.get(\"RANK\", \"0\"))\n        self.reward_fn_key = reward_fn_key\n    def _process_single_item(self, data_item):\n        prompt_length = data_item.batch[\"prompts\"].shape[-1]\n        valid_response_length = data_item.batch[\"attention_mask\"][prompt_length:].sum()\n        response_ids = data_item.batch[\"responses\"]\n        item = {\n            \"valid_response_ids\": response_ids[:valid_response_length].cpu(),\n            \"ground_truth\": data_item.non_tensor_batch[\"reward_model\"][\"ground_truth\"],\n            \"data_source\": data_item.non_tensor_batch[self.reward_fn_key],\n            \"extra_info\": data_item.non_tensor_batch.get(\"extra_info\", None),\n        }\n        return item\n\n    def _compute_score(self, item):\n        response_str = self.tokenizer.decode(item[\"valid_response_ids\"])\n        return self.compute_score(\n            data_source=item[\"data_source\"],\n            solution_str=response_str,\n            ground_truth=item[\"ground_truth\"],\n            extra_info=item[\"extra_info\"],\n        )\n\n    def verify(self, data):\n        with mp_dummy.Pool(processes=mp.cpu_count() // 2) as pool:\n            items = [self._process_single_item(data[i]) for i in range(len(data))]\n            scores = pool.map(partial(self._compute_score), items)\n            data[\"acc\"] = torch.tensor(scores, dtype=torch.float32, device=data[0][\"prompts\"].device)\n        return scores\n\n    def __call__(self, data: TensorDict, return_dict: bool = True):\n        \"\"\"We will expand this function gradually based on the available datasets\"\"\"\n\n        # If there is rm score, we directly return rm score. 
Otherwise, we compute via rm_score_fn\n        if \"rm_scores\" in data.keys():\n            return data[\"rm_scores\"]\n\n        scores = self.verify(data)\n        reward_tensor = torch.zeros_like(data[\"responses\"], dtype=torch.float32)\n\n        already_print_data_sources = {}\n\n        for i in range(len(data)):\n            data_item = data[i]\n            prompt_length = data_item[\"prompts\"].shape[-1]\n            valid_response_length = data_item[\"attention_mask\"][prompt_length:].sum()\n            reward_tensor[i, valid_response_length - 1] = scores[i]\n\n            data_source = data[\"data_source\"][i]\n            if data_source not in already_print_data_sources:\n                already_print_data_sources[data_source] = 0\n\n            if self.rank == 0 and already_print_data_sources[data_source] < self.num_examine:\n                already_print_data_sources[data_source] += 1\n                valid_prompt_length = data_item[\"attention_mask\"][:prompt_length].sum()\n                prompt_str = self.tokenizer.decode(data_item[\"prompts\"][-valid_prompt_length:])\n                response_str = self.tokenizer.decode(data_item[\"responses\"][:valid_response_length])\n                print(\"[prompt]\", prompt_str.replace(IMG_CONTEXT_TOKEN, \"\"))\n                print(\"[response]\", response_str)\n                print(\"[ground_truth]\", data[\"reward_model\"][i][\"ground_truth\"])\n                print(\"[score]\", scores[i])\n        if return_dict:\n            return {\n                    \"reward_tensor\": reward_tensor,\n                    \"reward_extra_info\": {},\n                }\n        return reward_tensor\n"
  },
  {
    "path": "siirl/engine/reward_model/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .base import BasePPORewardModel\n\n__all__ = [\"BasePPORewardModel\"]\n"
  },
  {
    "path": "siirl/engine/reward_model/base.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThe base class for reward model\n\"\"\"\n\nfrom abc import ABC, abstractmethod\n\nfrom tensordict import TensorDict\n\n\nclass BasePPORewardModel(ABC):\n    def __init__(self, config):\n        self.config = config\n\n    @abstractmethod\n    def compute_reward(self, data: TensorDict) -> TensorDict:\n        \"\"\"Computing reward given input_ids. The transformers should output a tensor with shape\n           [batch_size, sequence_length], and the value at [EOS] mask should be gathered.\n\n        Args:\n            data: must contain keys \"input_ids\", \"attention_mask\" and \"position_ids\".\n                - input_ids: [batch_size, sequence_length]\n                - attention_mask: [batch_size, sequence_length]\n                - position_ids: [batch_size, sequence_length]\n\n        Returns: a data pass protocol containing \"reward\". Only the [EOS] position contains the reward.\n            Other position should have zero reward. Note that this may change in the future if we use\n            dense reward. So, we leave the interface for general case.\n            - reward: [batch_size, sequence_length].\n\n        \"\"\"\n        pass\n"
  },
  {
    "path": "siirl/engine/reward_model/megatron/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .reward_model import MegatronRewardModel\n\n__all__ = [\"MegatronRewardModel\"]\n"
  },
  {
    "path": "siirl/engine/reward_model/megatron/reward_model.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nMegatron Reward Model.\n\"\"\"\n\nimport itertools\n\nimport torch\nimport torch.distributed\nfrom megatron.core import parallel_state as mpu\nfrom megatron.core.pipeline_parallel import get_forward_backward_func\nfrom tensordict import TensorDict\n\nfrom siirl.utils.extras.device import get_device_id, get_device_name, get_torch_device\nfrom siirl.utils.megatron.pipeline_parallel import make_batch_generator\nfrom siirl.utils.model_utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches\nfrom siirl.utils.model_utils.torch_functional import broadcast_dict_tensor, pad_sequence_to_length\nfrom siirl.engine.reward_model.base import BasePPORewardModel\n\n\nclass MegatronRewardModel(BasePPORewardModel):\n    def __init__(\n        self,\n        config,\n        model_config,\n        reward_model_module: torch.nn.ModuleList,\n        hf_config,\n        tf_config,\n        sft_tokenizer=None,\n        rm_tokenizer=None,\n    ):\n        self.config = config\n        self.reward_model_module = reward_model_module\n        self.hf_config = hf_config\n        self.tf_config = tf_config\n        self.model_config = model_config\n        self.device = \"cuda\"\n        self.sft_tokenizer = sft_tokenizer\n        self.rm_tokenizer = rm_tokenizer\n        self.use_different_tokenizer = rm_tokenizer is not None\n\n        print(f\"MegatronRewardModel.config: {self.config}\")\n\n        if self.config.megatron.param_offload:\n            self.offload_params_to_cpu()\n\n    def re_encode_by_rm_tokenizer(self, data: TensorDict) -> TensorDict:\n        assert self.use_different_tokenizer, \"re-encode need rm tokenizer not be None!\"\n        # need to use rm tokenizer to re-generate input_ids, attention_mask and position_ids\n        # 1. remove pad for each sequence\n        # 2. decode by sft_tokenizer, remove sft system prompts\n        # 3. encode by rm_tokenizer with rm system prompts, get rm_input_ids\n        # 4. generate attention_mask and position_ids\n        input_ids = data.batch[\"input_ids\"]  # (bs, seq_len)\n        attention_mask = data.batch[\"attention_mask\"]\n        position_ids = data.batch[\"position_ids\"]\n        ori_values = {\"input_ids\": input_ids, \"attention_mask\": attention_mask, \"position_ids\": position_ids}\n        _, ori_seqlen = input_ids.size(0), input_ids.size(1)\n        input_ids_for_rm = []\n        attention_mask_for_rm = []\n        position_ids_for_rm = []\n        print_decode = True\n        ori_seqlen = ori_seqlen + 128\n        for id, mask in zip(input_ids, attention_mask):\n            # 1. remove pad for each sequence\n            non_zero_indices = torch.nonzero(mask).view(-1)\n            begin_pos, end_pos = non_zero_indices[0].item(), non_zero_indices[-1].item()\n            valid_id = id[begin_pos : end_pos + 1]\n            # 2. 
decode by sft_tokenizer, remove sft system prompts\n            decode_result = self.sft_tokenizer.decode(valid_id)\n            # workaround\n            decode_with_rm_chat = decode_result.replace(\"<|user|>\\n\", \"[INST] \").replace(\"</s>\\n<|assistant|>\\n\", \" [/INST]\").replace(\"</s> \\n<|assistant|>\\n\", \" [/INST]\") + \"</s>\"\n            if print_decode and torch.distributed.get_rank() == 0:\n                # only print first decode result\n                print(\n                    f\"device {get_device_id()}: sft decode result:\\n{decode_result}\\n \\\n                        \\ndevice {get_device_id()}: sft decode result with \\\n                        rm chat template:\\n{decode_with_rm_chat}\\n\\n\"\n                )\n                print_decode = False\n            # 3. encode by rm_tokenizer\n            rm_input_ids = self.rm_tokenizer(decode_with_rm_chat, return_tensors=\"pt\")[\"input_ids\"][0].to(input_ids.device)\n            # 4. generate attention_mask and position_ids\n            rm_attention_mask = torch.ones_like(rm_input_ids, device=input_ids.device)\n            cur_seqlen = rm_input_ids.shape[-1]\n            # NOTE(gh): the later reward compute will process the shape (bs, seqlen_pad_128)\n            if cur_seqlen > ori_seqlen:\n                print(f\"warning: rm encode seqlen {cur_seqlen} > sft encode seqlen {ori_seqlen}\")\n                rm_input_ids = rm_input_ids[:ori_seqlen]\n                rm_attention_mask = rm_attention_mask[:ori_seqlen]\n            else:\n                # right padding\n                rm_input_ids = pad_sequence_to_length(rm_input_ids, ori_seqlen, self.rm_tokenizer.pad_token_id)\n                rm_attention_mask = pad_sequence_to_length(rm_attention_mask, ori_seqlen, 0)\n            rm_position_ids = torch.arange(0, ori_seqlen, device=input_ids.device)\n            input_ids_for_rm.append(torch.unsqueeze(rm_input_ids, dim=0))\n            attention_mask_for_rm.append(torch.unsqueeze(rm_attention_mask, dim=0))\n            position_ids_for_rm.append(torch.unsqueeze(rm_position_ids, dim=0))\n        input_ids_for_rm = torch.cat(input_ids_for_rm, dim=0)\n        attention_mask_for_rm = torch.cat(attention_mask_for_rm, dim=0)\n        position_ids_for_rm = torch.cat(position_ids_for_rm, dim=0)\n\n        # (bs, seqlen) will not change, but input_ids, attention_mask and position_ids will change\n        # NOTE(gh): need to replace into origin values after compute reward!\n        data.batch[\"input_ids\"] = input_ids_for_rm\n        data.batch[\"attention_mask\"] = attention_mask_for_rm\n        data.batch[\"position_ids\"] = position_ids_for_rm\n\n        return data, ori_values\n\n    @torch.no_grad()\n    def compute_reward(self, data: TensorDict) -> TensorDict:\n        if self.config.megatron.param_offload:\n            self.load_params_to_cuda()\n\n        if self.use_different_tokenizer:\n            data, ori_values = self.re_encode_by_rm_tokenizer(data)\n\n        input_ids = data.batch[\"input_ids\"]  # (bs, seq_len')\n        attention_mask = data.batch[\"attention_mask\"]\n        position_ids = data.batch[\"position_ids\"]\n        use_dynamic_bsz = data.meta_info.get(\"use_dynamic_bsz\", False)\n        micro_batch_size = data.meta_info.get(\"micro_batch_size\", None)\n        max_token_len = data.meta_info.get(\"max_token_len\", None)\n        assert micro_batch_size is not None, \"micro batch size is needed for forward compute\"\n        if use_dynamic_bsz:\n            assert max_token_len is not 
None, \"use_dynamic_bsz is True, but max_token_len is None!\"\n            max_token_len = max_token_len * self.config.megatron.context_parallel_size\n\n        responses = data.batch[\"responses\"]\n        batch_size = responses.size(0)\n        response_length = responses.size(1)\n\n        with torch.no_grad():\n            output = self.forward_batch(data, use_dynamic_bsz=use_dynamic_bsz, micro_batch_size=micro_batch_size, max_token_len=max_token_len)\n            if mpu.is_pipeline_last_stage(ignore_virtual=True):\n                logits = torch.cat(output[\"output\"], dim=0)\n                if use_dynamic_bsz:\n                    indices = output[\"indices\"]\n                    indices = list(itertools.chain.from_iterable(indices))\n                    assert len(indices) == logits.size(0), f\"{len(indices)} vs. {logits.size()}\"\n                    revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)\n                    logits = logits[revert_indices]\n            else:\n                logits = torch.empty(\n                    (input_ids.shape[0], input_ids.shape[1]),\n                    device=input_ids.device,\n                )\n            logits = logits.to(torch.float32)\n\n            # broadcast across pp ranks\n            torch.distributed.broadcast(\n                tensor=logits,\n                src=mpu.get_pipeline_model_parallel_last_rank(),\n                group=mpu.get_pipeline_model_parallel_group(),\n                async_op=False,\n            )\n\n        # (bs, seqlen', hidden_size) -> (bs, seqlen', 1) -> (bs, seqlen')\n        token_level_rewards = logits\n        # find the last token reward\n        ends = attention_mask.cumsum(dim=-1).argmax(dim=-1).view(-1, 1)  # (bs, 1)\n        rewards = torch.gather(token_level_rewards, dim=1, index=ends)  # (bs, 1)\n\n        if self.use_different_tokenizer:\n            data.batch.update(ori_values)\n            input_ids = ori_values[\"input_ids\"]\n            attention_mask = ori_values[\"attention_mask\"]\n            position_ids = ori_values[\"position_ids\"]\n\n        token_level_rewards = rewards.expand(attention_mask.shape[0], attention_mask.shape[1])  # (bs, ori_seqlen)\n\n        # assign last valid token reward to ori position\n        eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1)  # (bs,)\n        eos_mask = torch.zeros_like(attention_mask)\n        eos_mask[torch.arange(batch_size), eos_mask_idx] = 1.0\n\n        token_level_rewards = token_level_rewards * eos_mask\n        token_level_rewards = token_level_rewards[:, -response_length:]\n\n        if self.config.megatron.param_offload:\n            self.offload_params_to_cpu()\n        else:\n            # add empty cache after each compute\n            get_torch_device().empty_cache()\n\n        batch = TensorDict({\"rm_scores\": token_level_rewards}, batch_size=input_ids.shape[0])\n\n        return TensorDict(batch=batch)\n\n    def forward_batch(self, data: TensorDict, use_dynamic_bsz=False, micro_batch_size=None, max_token_len=None):\n        \"\"\"\n        We assume:\n        - The model takes input: (input_ids, attention_mask, position_ids). 
No rmpad for the input\n        - The communication shape is (total_nnz_pad_to_sp // tp_size, 1, hidden_size) if sequence parallel is enabled\n        \"\"\"\n        # broadcast from last pp rank to all other pp ranks\n        # TODO: actually, we just need to control the sampling order.\n        mini_batch = data\n        mini_batch.batch = mini_batch.batch.contiguous()\n        broadcast_dict_tensor(mini_batch.batch, src=mpu.get_pipeline_model_parallel_last_rank(), group=mpu.get_pipeline_model_parallel_group())\n        \n        # broadcast from cp rank 0 to all other cp ranks to ensure same data across CP group\n        cp_size = mpu.get_context_parallel_world_size()\n        if cp_size > 1:\n            cp_group = mpu.get_context_parallel_group()\n            cp_group_ranks = torch.distributed.get_process_group_ranks(cp_group)\n            src_rank = cp_group_ranks[0]  # cp_rank=0 in this group\n            broadcast_dict_tensor(\n                mini_batch.batch,\n                src=src_rank,\n                group=cp_group,\n            )\n\n        mini_batch.batch[\"attention_mask\"] = mini_batch.batch[\"attention_mask\"].to(bool)\n\n        indices = None\n        if use_dynamic_bsz:\n            assert max_token_len is not None, \"max_token_len must be set when use_dynamic_bsz is True\"\n            vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()\n            if vpp_size is not None and vpp_size > 1:\n                microbatch_group_size_per_vp_stage = self.tf_config.microbatch_group_size_per_vp_stage\n                micro_batches, indices = rearrange_micro_batches(batch=mini_batch.batch, num_batches_divided_by=microbatch_group_size_per_vp_stage, max_token_len=max_token_len)\n                assert len(micro_batches) % self.tf_config.microbatch_group_size_per_vp_stage == 0, f\"micro_batches {micro_batches} must be divisible by microbatch_group_size_per_vp_stage {microbatch_group_size_per_vp_stage} for megatron backend\"\n            else:\n                micro_batches, indices = rearrange_micro_batches(batch=mini_batch.batch, max_token_len=max_token_len)\n            total_seqlen = max_token_len\n        else:\n            assert micro_batch_size is not None, \"micro_batch_size is needed to be passed in when not using dynamic batch size\"\n            micro_batches = mini_batch.batch.split(micro_batch_size)\n            seq_len = micro_batches[0][\"input_ids\"].shape[1]\n            total_seqlen = micro_batch_size * seq_len\n        n_micro_batch = len(micro_batches)\n\n        # compute input shapes for pp stages\n        forward_backward_func = get_forward_backward_func()\n\n        def loss_func(output):\n            return torch.tensor(1.0, device=output.device), output\n\n        def forward_step(batch_iter, model):\n            batch = next(batch_iter)\n            input_ids = batch[\"input_ids\"]\n            attention_mask = batch[\"attention_mask\"]\n            position_ids = batch[\"position_ids\"]\n            from siirl.models.mcore import get_mcore_forward_fn\n\n            forward_fn = get_mcore_forward_fn(self.hf_config)\n\n            output = forward_fn(\n                model,\n                input_ids,\n                attention_mask,\n                position_ids,\n                sequence_parallel=self.tf_config.sequence_parallel,\n                value_model=True,\n            )\n\n            return output, loss_func\n\n        # batch should be a list of batches inside micro-batches\n        batch_generator = 
make_batch_generator(micro_batches, vpp_size=len(self.reward_model_module))\n\n        # TODO: we may use the new schedule instead\n        # for flash-attn: (seq_len, batch_size, hidden_size) = (mbs*seq_len, 1, hidden_size)\n        if mpu.get_pipeline_model_parallel_world_size() > 1:\n            losses_reduced = forward_backward_func(\n                forward_step_func=forward_step,\n                data_iterator=batch_generator,\n                model=self.reward_model_module,\n                num_microbatches=n_micro_batch,\n                seq_length=total_seqlen,  # no use when input_shapes was set\n                micro_batch_size=1,  # no use when input_shapes was set\n                forward_only=True,\n            )\n        else:\n            losses_reduced = forward_backward_func(\n                forward_step_func=forward_step,\n                data_iterator=batch_generator,\n                model=self.reward_model_module,\n                num_microbatches=n_micro_batch,\n                seq_length=total_seqlen,  # in use for pp = 1\n                micro_batch_size=1,  # in use for pp = 1\n                forward_only=True,\n            )\n        # losses_reduced contains the stats returned from loss_func\n        losses_reduced = {\"output\": losses_reduced}\n        if use_dynamic_bsz:\n            losses_reduced[\"indices\"] = indices\n        return losses_reduced\n\n    def offload_params_to_cpu(self):\n        if self.device in [\"cuda\", \"npu\"]:\n            for reward_model_module in self.reward_model_module:\n                for name, param in reward_model_module.named_parameters():\n                    param.data = param.data.to(\"cpu\", non_blocking=True)\n            self.device = \"cpu\"\n            get_torch_device().empty_cache()\n\n    def load_params_to_cuda(self):\n        if self.device == \"cpu\":\n            for reward_model_module in self.reward_model_module:\n                for name, param in reward_model_module.named_parameters():\n                    param.data = param.data.to(get_device_id(), non_blocking=True)\n            self.device = get_device_name()\n"
  },
  {
    "path": "siirl/engine/rollout/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .base import BaseRollout\nfrom .hf_rollout import HFRollout\n\n__all__ = [\"BaseRollout\", \"HFRollout\"]\n"
  },
  {
    "path": "siirl/engine/rollout/async_server.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport socket\nfrom abc import ABC, abstractmethod\nfrom typing import Any, Dict, List, Optional, Tuple, Type\n\n\nimport ray\n\nfrom starlette.requests import Request\n\nlogger = logging.getLogger(__file__)\n\n\ndef _get_free_port():\n    with socket.socket() as sock:\n        sock.bind((\"\", 0))\n        return sock.getsockname()[1]\n\n\nclass AsyncServerBase(ABC):\n    \"\"\"Base class for AsyncServer.\"\"\"\n\n    def __init__(self):\n        self.address = ray._private.services.get_node_ip_address()\n        self.port = None\n\n    @abstractmethod\n    async def chat_completion(self, raw_request: Request):\n        \"\"\"OpenAI chat completion API.\n\n        API reference: https://platform.openai.com/docs/api-reference/chat/create\n        \"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    async def generate(self, prompt_ids: List[int], sampling_params: Dict[str, Any], request_id: str) -> List[int]:\n        \"\"\"Generate response ids given prompt ids.\n\n        Args:\n            prompt_ids (List[int]): prompt ids\n            sampling_params (Dict[str, Any]): sampling params\n            request_id (str): request id\n\n        Returns:\n            List[int]: response ids\n        \"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    async def init_engine(self):\n        \"\"\"Init async LLM engine.\"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    async def wake_up(self):\n        \"\"\"Wake up engine to load model weights and build kv cache.\"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    async def sleep(self):\n        \"\"\"Sleep engine to offload model weights and discard kv cache.\"\"\"\n        raise NotImplementedError\n\n\ndef async_server_class(\n    rollout_backend: str, rollout_backend_module: Optional[str] = None, rollout_backend_class: Optional[str] = None\n) -> Type[AsyncServerBase]:\n    \"\"\"Get async server class.\n\n    Args:\n        rollout_backend: str, rollout backend type (alias), should be \"vllm\" or \"sglang\".\n        rollout_backend_module: Optional[str], import path of the rollout backend.\n        rollout_backend_class: Optional[str], class name of the rollout backend.\n\n    Returns:\n        Type[AsyncServerBase]: async server class.\n    \"\"\"\n    if rollout_backend_class is None and rollout_backend_module is None:\n        # If both are None, use the default backend class\n        # Do not change the original import behavior\n        # importlib.import_module and from ... import ... 
have subtle differences in ray\n\n        if rollout_backend == \"vllm\":\n            from siirl.engine.rollout.vllm_rollout.vllm_async_server import AsyncvLLMServer\n\n            return AsyncvLLMServer\n        elif rollout_backend == \"sglang\":\n            from siirl.engine.rollout.sglang_rollout.async_sglang_server import AsyncSglangServer\n            return AsyncSglangServer\n        else:\n            raise NotImplementedError(f\"rollout backend {rollout_backend} is not supported\")\n\n    if rollout_backend_module is None or rollout_backend_class is None:\n        raise ValueError(\"rollout_backend_module and rollout_backend_class must be both provided for customization\")\n\n    from siirl.utils.extras.import_utils import load_extern_type\n\n    return load_extern_type(rollout_backend_module, rollout_backend_class)\n"
  },
  {
    "path": "siirl/engine/rollout/base.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom abc import ABC, abstractmethod\n\nfrom tensordict import TensorDict\n\n__all__ = [\"BaseRollout\"]\n\n\nclass BaseRollout(ABC):\n    \"\"\"Base class for rollout.\"\"\"\n\n    @abstractmethod\n    def generate_sequences(self, prompts: TensorDict) -> TensorDict:\n        \"\"\"Generate sequences\"\"\"\n        pass\n"
  },
  {
    "path": "siirl/engine/rollout/embodied_rollout.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport contextlib\nfrom contextlib import contextmanager\nimport os\nimport time\nimport multiprocessing\nfrom collections import defaultdict\nfrom datetime import datetime\n\nimport numpy as np\nfrom siirl.params import ActorRolloutRefArguments\nimport tensorflow as tf\nimport torch\nimport torch.distributed\nfrom PIL import Image\nfrom loguru import logger\nfrom tensordict import TensorDict\nfrom torch import nn\nfrom torch.nn.utils.rnn import pad_sequence\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\nfrom transformers import AutoProcessor, GenerationConfig\n\nfrom siirl.utils.model_utils.torch_functional import get_eos_mask\nimport siirl.utils.model_utils.torch_functional as siirl_F\nfrom siirl.engine.rollout.base import BaseRollout\nfrom siirl.utils.embodied.video_emb import VideoEmbeddingModel\n\n\nif multiprocessing.get_start_method(allow_none=True) != \"spawn\":\n    multiprocessing.set_start_method(\"spawn\", force=True)\n\nfrom siirl.environment.embodied import LIBEROAdapter\n\n\n__all__ = ['RobHFRollout']\n\nOPENVLA_V01_SYSTEM_PROMPT = (\n    \"A chat between a curious user and an artificial intelligence assistant. \"\n    \"The assistant gives helpful, detailed, and polite answers to the user's questions.\"\n)\n\n\n@contextmanager\ndef _timer(name: str, timing_dict: dict):\n    \"\"\"A context manager to measure execution time of a code block.\"\"\"\n    start_time = time.perf_counter()\n    yield\n    end_time = time.perf_counter()\n    timing_dict[name] = timing_dict.get(name, 0) + end_time - start_time\n\n\ndef crop_and_resize(image, crop_scale, batch_size):\n    \"\"\"\n    Center-crops an image to have area `crop_scale` * (original image area), and then resizes back\n    to original size. 
We use the same logic seen in the `dlimp` RLDS datasets wrapper to avoid\n    distribution shift at test time.\n\n    Args:\n        image: TF Tensor of shape (batch_size, H, W, C) or (H, W, C) and datatype tf.float32 with\n               values between [0,1].\n        crop_scale: The area of the center crop with respect to the original image.\n        batch_size: Batch size.\n    \"\"\"\n    # Convert from 3D Tensor (H, W, C) to 4D Tensor (batch_size, H, W, C)\n    assert image.shape.ndims == 3 or image.shape.ndims == 4\n    expanded_dims = False\n    if image.shape.ndims == 3:\n        image = tf.expand_dims(image, axis=0)\n        expanded_dims = True\n\n    # Get height and width of crop\n    new_heights = tf.reshape(tf.clip_by_value(tf.sqrt(crop_scale), 0, 1), shape=(batch_size,))\n    new_widths = tf.reshape(tf.clip_by_value(tf.sqrt(crop_scale), 0, 1), shape=(batch_size,))\n\n    # Get bounding box representing crop\n    height_offsets = (1 - new_heights) / 2\n    width_offsets = (1 - new_widths) / 2\n    bounding_boxes = tf.stack(\n        [\n            height_offsets,\n            width_offsets,\n            height_offsets + new_heights,\n            width_offsets + new_widths,\n        ],\n        axis=1,\n    )\n\n    # Crop and then resize back up\n    image = tf.image.crop_and_resize(image, bounding_boxes, tf.range(batch_size), (224, 224))\n\n    # Convert back to 3D Tensor (H, W, C)\n    if expanded_dims:\n        image = image[0]\n\n    return image\n\n\ndef center_crop_image(image):\n    batch_size = 1\n    crop_scale = 0.9\n\n    # Convert to TF Tensor and record original data type (should be tf.uint8)\n    image = tf.convert_to_tensor(np.array(image))\n    orig_dtype = image.dtype\n\n    # Convert to data type tf.float32 and values between [0,1]\n    image = tf.image.convert_image_dtype(image, tf.float32)\n\n    # Crop and then resize back to original size\n    image = crop_and_resize(image, crop_scale, batch_size)\n\n    # Convert back to original data type\n    image = tf.clip_by_value(image, 0, 1)\n    image = tf.image.convert_image_dtype(image, orig_dtype, saturate=True)\n\n    # Convert back to PIL Image\n    image = Image.fromarray(image.numpy())\n    image = image.convert(\"RGB\")\n    return image\n\n\nclass EmbodiedHFRollout(BaseRollout):\n\n    def __init__(self, module: nn.Module, config: ActorRolloutRefArguments):\n        super().__init__()\n        self.config = config\n        self.model = module\n        self.max_steps = {\n            \"libero_spatial\": 512,\n            \"libero_object\": 512,\n            \"libero_goal\": 512,\n            \"libero_10\": 512,\n            \"libero_90\": 512\n        }\n        self._embodied_processed = False\n\n        self._rank = torch.distributed.get_rank(\n        ) if torch.distributed.is_initialized() else 0\n        self._num_gpus_per_node = self.config.embodied.n_gpus_per_node\n\n        self.embedding_model = VideoEmbeddingModel(\n            model_path=config.embodied.video_embedding_model_path,\n            img_size=config.embodied.embedding_img_size,\n            enable_fp16=config.embodied.embedding_enable_fp16\n        )\n\n        self.enable_perf = os.environ.get(\"SIIRL_ENABLE_PERF\", \"0\") == \"1\"\n        self.embedding_model_offload = config.embodied.embedding_model_offload\n\n        # Initialize LIBEROAdapter\n        self.num_workers = self.config.embodied.env.num_envs\n        # Distribute workers across available GPUs based on rank\n        self.adapter = LIBEROAdapter(\n            
env_name=self.config.embodied.env.env_name,\n            num_envs=self.num_workers,\n            max_steps=self.config.embodied.env.max_steps,\n            num_steps_wait=self.config.embodied.env.num_steps_wait,\n            model_family=self.config.embodied.env.model_family,\n            gpu_ids=[self._rank % self._num_gpus_per_node] # Run all workers on the same assigned GPU\n        )\n        logger.info(\n            f\"Initializing LIBEROAdapter with {self.num_workers} environments...\")\n\n    def close(self):\n        \"\"\"Gracefully shuts down the environment adapter.\"\"\"\n        logger.info(\"Closing LIBEROAdapter...\")\n        if hasattr(self, 'adapter') and self.adapter:\n            self.adapter.close()\n        logger.info(\"LIBEROAdapter closed.\")\n\n    def __del__(self):\n        # Ensure workers are closed when the object is garbage collected\n        self.close()\n\n    def embodied_preprocess(self):\n        self.processor = AutoProcessor.from_pretrained(self.config.model.path, trust_remote_code=True)\n\n        if self.config.embodied.embodied_type in [\"openvla\", \"openvla-oft\"]:\n            gpus = tf.config.experimental.list_physical_devices('GPU')\n            if gpus:\n                for gpu in gpus:\n                    tf.config.experimental.set_memory_growth(gpu, True)\n\n        if self.config.embodied.embodied_type in [\"openvla-oft\"]:\n            if self.config.embodied.unnorm_key not in self.model.norm_stats and f\"{self.config.embodied.unnorm_key}_no_noops\" in self.model.norm_stats:\n                self.config.embodied.unnorm_key = f\"{self.config.embodied.unnorm_key}_no_noops\"\n            assert self.config.embodied.unnorm_key in self.model.norm_stats, f\"Action un-norm key {self.config.embodied.unnorm_key} not found in VLA `norm_stats`!\"\n\n    def generate_sequences(self, prompts):\n        \"\"\"\n        Main entry point for generating sequences.\n        It splits a large batch of prompts into chunks that fit the number of workers,\n        processes each chunk to generate a rollout, and then concatenates the results.\n        This mimics the behavior of the original script to ensure data format compatibility.\n        \"\"\"\n        # Preprocess the VLA model only once\n        if not self._embodied_processed:\n            self.embodied_preprocess()\n            self._embodied_processed = True\n\n        tic = time.time()\n\n        total_batch_size = prompts.batch_size[0]\n        n_samples = prompts['n_samples'] if 'n_samples' in prompts else 1\n        assert self.num_workers >= n_samples, f\"rollout num_workers({self.num_workers}) must be >= n_samples({n_samples})\"\n        batch_size_per_chunk = self.num_workers\n        num_chunks = (total_batch_size + batch_size_per_chunk - 1) // batch_size_per_chunk\n        logger.info(f\"RobHFRollout.generate_sequences called with total batch size {total_batch_size}, \"\n                    f\"n_samples {n_samples}, num_workers {self.num_workers}, batch_size_per_chunk {batch_size_per_chunk}, \"\n                    f\"num_chunks {num_chunks}\")\n        \n        all_chunk_outputs = []\n\n        for i in range(num_chunks):\n            start_idx = i * batch_size_per_chunk\n            end_idx = min((i + 1) * batch_size_per_chunk, total_batch_size)\n            \n            # Slice the prompts to create a chunk\n            chunk_prompts = prompts[start_idx:end_idx]\n            \n            logger.info(\n                f\"--- Processing chunk {i+1}/{num_chunks}, size = 
{chunk_prompts.batch_size[0]} ---\")\n            \n            # Process one chunk and get its TensorDict output\n            chunk_output = self._generate_chunk_rollout(chunk_prompts)\n            all_chunk_outputs.append(chunk_output)\n\n        # Concatenate the TensorDict objects from all chunks\n        final_output = torch.cat(all_chunk_outputs)\n        logger.info(f\"RobHFRollout.generate_sequences finished for a single batch of size {final_output.batch_size[0]}\"\n                    f\", took {time.time() - tic:.2f} seconds\")\n        return final_output\n\n    def process_input(self,inputs:list, task_descriptions:list):\n        \n        batchdata = {\"input_ids\":[],\"attention_mask\":[],\"pixel_values\":[]}  \n        \n        for i in range(len(inputs)):\n            input = inputs[i]\n            task_description = task_descriptions[i]\n           \n            image = Image.fromarray(input[\"full_image\"]).convert(\"RGB\")\n            if self.config.embodied.center_crop:\n                image = center_crop_image(image)\n            prompt = f\"In: What action should the robot take to {task_description.lower()}?\\nOut:\"\n            batch_feature  = self.processor(prompt, image)\n            \n            if \"wrist_image\" in input.keys():\n                wrist_image = Image.fromarray(input[\"wrist_image\"]).convert(\"RGB\")\n                if self.config.embodied.center_crop:\n                    wrist_image = center_crop_image(wrist_image)\n                wrist_batch_feature = self.processor(prompt, wrist_image)\n                primary_pixel_values = batch_feature[\"pixel_values\"]\n                batch_feature[\"pixel_values\"] = torch.cat([primary_pixel_values] + [wrist_batch_feature[\"pixel_values\"]], dim=1)\n                \n            input_ids = batch_feature[\"input_ids\"]\n            attention_mask = batch_feature[\"attention_mask\"]\n            pixel_values = batch_feature[\"pixel_values\"]\n            \n            if not torch.all(input_ids[:, -1] == 29871):\n                input_ids = torch.cat(\n                    (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1\n                )\n                if self.config.embodied.embodied_type in [\"openvla-oft\"]:\n                    attention_mask = torch.cat(\n                        (attention_mask, torch.unsqueeze(torch.Tensor([True]).bool(), dim=0).to(attention_mask.device)), dim=1\n                    )\n            \n            batchdata[\"input_ids\"].append(input_ids)    \n            batchdata[\"attention_mask\"].append(attention_mask)    \n            batchdata[\"pixel_values\"].append(pixel_values)    \n        \n        \n        device = torch.device('cuda') \n        \n        if self.config.embodied.embodied_type in [\"openvla-oft\"]:\n            batchdata[\"input_ids\"] = [x.transpose(0, 1) for x in batchdata[\"input_ids\"]]\n            batchdata[\"attention_mask\"] = [x.transpose(0, 1) for x in batchdata[\"attention_mask\"]]\n            batchdata[\"input_ids\"] = pad_sequence(batchdata[\"input_ids\"], batch_first=True, padding_value=self.processor.tokenizer.pad_token_id).squeeze(-1).to(device)\n            batchdata[\"attention_mask\"] = pad_sequence(batchdata[\"attention_mask\"], batch_first=True, padding_value=0).squeeze(-1).to(device)\n            \n            padding_mask = batchdata[\"input_ids\"].ne(self.processor.tokenizer.pad_token_id)\n            assert  torch.all(padding_mask==batchdata[\"attention_mask\"].ne(0))\n     
       padding_mask = ~padding_mask\n            padding_mask = padding_mask.int() \n            sorted_indices = torch.argsort(padding_mask, dim=1, descending=True, stable=True)\n            batchdata[\"input_ids\"] = torch.gather(batchdata[\"input_ids\"], 1, sorted_indices)\n            batchdata[\"attention_mask\"] = torch.gather(batchdata[\"attention_mask\"], 1, sorted_indices)\n            \n            \n            batchdata[\"pixel_values\"] = torch.cat(batchdata[\"pixel_values\"] , dim=0).to(device)\n            assert torch.all(batchdata[\"attention_mask\"].ne(0) == batchdata[\"input_ids\"].ne(self.processor.tokenizer.pad_token_id))\n        else:\n            for key in [\"input_ids\", \"attention_mask\", \"pixel_values\"]:\n                batchdata[key] = torch.cat(batchdata[key], dim=0).to(device)\n\n        return batchdata\n\n    def _generate_chunk_rollout(self, prompts):\n        generate_tic = time.time()\n        self.model.eval()\n        \n        # Validation mode has no n_samples; training mode has n_samples for repeat\n        is_valid = 'n_samples' not in prompts.keys()\n        n_samples = 1\n        if not is_valid:\n            val = prompts['n_samples']\n            n_samples = val.item() if hasattr(val, 'item') else int(val)\n        global_steps = prompts['global_steps'] if 'global_steps' in prompts.keys() else 0\n\n        # dataloader already did repeat, rollout does NOT repeat\n        task_id = prompts['task_id']\n        trial_id = prompts['trial_id']\n        task_suite_name = prompts['task_suite_name']\n        \n        assert np.all(task_suite_name == self.config.embodied.env.env_name), \\\n            \"All task_suite_name in the batch must match the rollout config\"\n        max_steps = self.config.embodied.env.max_steps\n        chunk_size = task_id.size(0)\n\n        timing_dict = {}\n\n        # Reset environments using the adapter\n        with _timer(f\"adapter_reset\", timing_dict):\n            # This is a blocking call\n            init_data_list = self.adapter._blocking_reset(\n                task_ids=task_id.reshape(-1).cpu().numpy().tolist(),\n                trial_ids=trial_id.reshape(-1).cpu().numpy().tolist(),\n            )\n\n        inputs = [None] * chunk_size\n        task_descriptions = [None] * chunk_size\n        task_records = [None] * chunk_size\n        valid_video = defaultdict(list)\n        all_video = defaultdict(list)\n\n        # Collect initial observations for the chunk\n        with _timer(f\"process_initial_obs\", timing_dict):\n            for idx in range(chunk_size):\n                init_data = init_data_list[idx]\n                task_descriptions[idx] = init_data[\"task_description\"]\n                inputs[idx] = self._obs_to_input(init_data['obs'])\n                task_records[idx] = {\n                    \"active\": init_data['active'],\n                    \"complete\": init_data['complete'],\n                    \"finish_step\": init_data['finish_step'],\n                    \"task_file_name\": init_data['task_file_name']\n                }\n                if is_valid:\n                    valid_video[init_data['task_file_name']].extend(\n                        init_data['valid_images'])\n                all_video[init_data['task_file_name']].extend(\n                    init_data['valid_images'])\n\n        step = 0\n        vla_history = []\n        meta_info_keys = [\"eos_token_id\", \"pad_token_id\", \"recompute_log_prob\", \"validate\", \"do_sample\", \"global_steps\"]\n        
meta_info_keys = [key for key in meta_info_keys if key in prompts.keys()]\n        meta_info = prompts.select(*meta_info_keys)\n        while step < max_steps:\n            active_indices = [i for i, r in enumerate(task_records) if r['active']]\n            current_inputs = inputs\n            current_task_descriptions = task_descriptions\n\n            with _timer(f\"process_input\", timing_dict):\n                vla_input = self.process_input(current_inputs, current_task_descriptions)\n            vla_input.update(meta_info)\n\n            with _timer(f\"_generate_one_step\", timing_dict):\n                vla_output = self._generate_one_step(vla_input)\n            \n            actions = vla_output[\"action\"]\n\n            step_data = {\n                \"responses\": vla_output[\"responses\"],\n                \"input_ids\": vla_output[\"input_ids\"],\n                \"attention_mask\": vla_output[\"attention_mask\"],\n                \"pixel_values\": vla_output[\"pixel_values\"],\n                \"action\": actions,\n                \"step\": step\n            }\n            vla_history.append(step_data)\n\n            with _timer(f\"adapter_step\", timing_dict):\n                step_results_list = self.adapter._blocking_step({\n                    \"indices\": active_indices,\n                    \"actions\": actions,\n                })\n            \n            with _timer(f\"process_step_results\", timing_dict):\n                new_inputs = inputs.copy()\n                for idx in active_indices:\n                    result = step_results_list[idx]\n                    new_inputs[idx] = self._obs_to_input(result['obs'])\n                    task_records[idx]['active'] = result['active']\n                    task_records[idx]['complete'] = result['complete']\n                    task_records[idx]['finish_step'] = result['finish_step']\n                    all_video[task_records[idx]['task_file_name']].extend(result['valid_images'])\n                    if is_valid:\n                        valid_video[task_records[idx]['task_file_name']].extend(result['valid_images'])\n                inputs = new_inputs\n            \n            step += self.config.embodied.action_chunks_len\n        \n        with _timer(f\"post_loop_processing\", timing_dict):\n            torch.cuda.empty_cache()\n            self.model.train()\n            \n            batch = {\n                    'responses': [],\n                    'input_ids': [],  # here input_ids become the whole sentences\n                    'attention_mask': [],\n                    'pixel_values': [],\n                }\n            for k in [\"responses\", \"input_ids\", \"attention_mask\", \"pixel_values\"]:\n                for h in vla_history:\n                    batch[k].append(h[k])\n            \n            for k,v in batch.items():\n                batch[k] = torch.stack(v,dim=1) \n    \n            batch[\"complete\"] = []\n            batch[\"finish_step\"] = []\n            batch[\"task_file_name\"] = []\n    \n            for k in task_records:\n                batch[\"complete\"].append(k[\"complete\"])\n                batch[\"finish_step\"].append(k[\"finish_step\"])\n                batch[\"task_file_name\"].append(k[\"task_file_name\"])\n            \n            batch[\"complete\"] = torch.tensor(batch[\"complete\"], dtype=torch.bool, device=batch['responses'].device)\n            batch[\"finish_step\"] = torch.tensor(batch[\"finish_step\"], dtype=torch.int64, device=batch['responses'].device)\n    
        # Build batch\n            names = batch[\"task_file_name\"]\n            max_len = 50 # max(len(n) for n in names)\n            padded = [n.ljust(max_len, '\\0') for n in names]\n            batch[\"task_file_name\"] = torch.tensor(\n                [s.encode('utf-8') for s in padded],\n                dtype=torch.uint8,\n                device=batch['responses'].device\n            )\n\n        vjepa_embeddings = []\n        tasks_for_embedding = []\n        for k in task_records:\n            tasks_for_embedding.append((\n                k['task_file_name'],\n                all_video.get(k['task_file_name'], []),\n                \"rollout4embedding\",\n                global_steps,\n                k['complete']\n            ))\n        \n        with _timer(f\"get_embeddings\", timing_dict):\n            batch_names, batch_frames = zip(*[(t[0], t[1])  for t in tasks_for_embedding])\n            vjepa_embeddings = self.embedding_model.get_embeddings(batch_names, batch_frames)\n            batch[\"vjepa_embedding\"] = torch.tensor(\n                np.array(vjepa_embeddings), dtype=torch.float32)\n\n        if self.enable_perf:\n            generate_chunk_rollout_time = time.time() - generate_tic\n            log_str = f\"\\n--- ⏱️  Chunk Performance (size={chunk_size}) ---\\n\"\n            \n            # Sort the dictionary by value in descending order for better readability\n            sorted_timing = sorted(timing_dict.items(), key=lambda item: item[1], reverse=True)\n\n            for key, value in sorted_timing:\n                log_str += f\"  {key}: {value:.4f} seconds\\n\"\n            log_str += f\"  _generate_chunk_rollout: {generate_chunk_rollout_time:.4f} seconds\\n\"\n            log_str += f\"  total steps in chunk: {step}\\n\"\n            log_str += \"--- ⏱️  End of Chunk Performance Log ---\\n\"\n\n            with open(f\"rollout_performance_rank_{self._rank}.log\", \"a\") as f:\n                f.write(f\"\\n{datetime.now()}:\\n\")\n                f.write(log_str)\n            logger.info(log_str)\n        \n        output_batch = TensorDict(\n            batch,\n            batch_size=chunk_size)\n\n        return output_batch\n\n    @torch.no_grad()\n    def _generate_one_step(self, prompts: dict):\n        if self.config.embodied.embodied_type == \"openvla-oft\":\n            idx = prompts['input_ids']  # (bs, prompt_length)\n            attention_mask = prompts['attention_mask']  # left-padded attention_mask\n            pixel_values = prompts[\"pixel_values\"]\n        \n        \n            param_ctx = contextlib.nullcontext()\n\n            # make sampling args can be overridden by inputs\n            do_sample = prompts.get('do_sample', self.config.rollout.do_sample)\n        \n\n            temperature = prompts.get('temperature', self.config.rollout.temperature)\n\n            #generation_config = GenerationConfig(temperature=temperature, top_p=top_p, top_k=top_k)\n\n            if isinstance(self.model, FSDP):\n                # recurse need to set to False according to https://github.com/pytorch/pytorch/issues/100069\n                param_ctx = FSDP.summon_full_params(self.model, writeback=False, recurse=False)\n            \n            with param_ctx:\n                with torch.autocast(device_type='cuda', dtype=torch.bfloat16):\n                    actions, response = self.model.generate_action_verl(\n                        input_ids=idx,\n                        pixel_values=pixel_values,\n                        
attention_mask=attention_mask,\n                        padding_idx = self.processor.tokenizer.pad_token_id,\n                        do_sample=do_sample,\n                        unnorm_key=self.config.embodied.unnorm_key,\n                        temperature=temperature, )\n            \n            \n            assert self.processor.tokenizer.pad_token_id is not None\n\n            assert idx.ndim == 2\n            idx = siirl_F.pad_sequence_to_length(idx,max_seq_len=self.config.rollout.prompt_length,pad_token_id=self.processor.tokenizer.pad_token_id,left_pad=True)\n            \n            assert attention_mask.ndim == 2\n            attention_mask = siirl_F.pad_sequence_to_length(attention_mask,max_seq_len=self.config.rollout.prompt_length,pad_token_id=0,left_pad=True)\n            \n            \n            assert idx.device.type == 'cuda'\n            assert response.device.type == 'cuda'\n            #assert seq.device.type == 'cuda'\n            assert attention_mask.device.type == 'cuda'\n            assert pixel_values.device.type == 'cuda'\n            batch ={\n                    'responses': response,\n                    'input_ids': idx,\n                    'attention_mask': attention_mask,\n                    \"pixel_values\":pixel_values,\n                    \"action\":actions,\n                }\n\n            return batch\n        \n        elif self.config.embodied.embodied_type == \"openvla\": \n            idx = prompts['input_ids']  # (bs, prompt_length)\n            attention_mask = prompts['attention_mask']  # left-padded attention_mask\n            pixel_values = prompts[\"pixel_values\"]\n            \n            # used to construct attention_mask\n            eos_token_id = prompts['eos_token_id']\n            pad_token_id = prompts['pad_token_id']\n\n            batch_size = idx.size(0)\n            prompt_length = idx.size(1)\n            #self.model.eval()\n            param_ctx = contextlib.nullcontext()\n\n            do_sample = prompts.get('do_sample', self.config.rollout.do_sample)\n            response_length =  self.model.get_action_dim(self.config.embodied.unnorm_key)\n            top_p = prompts.get('top_p', self.config.rollout.top_p)\n            top_k = prompts.get('top_k', self.config.rollout.top_k)\n            if top_k is None:\n                top_k = 0\n            top_k = max(0, top_k)  # to be compatible with vllm\n\n            temperature = prompts.get('temperature', self.config.rollout.temperature)\n            generation_config = GenerationConfig(temperature=temperature, top_p=top_p, top_k=top_k)\n\n            if isinstance(self.model, FSDP):\n                # recurse need to set to False according to https://github.com/pytorch/pytorch/issues/100069\n                param_ctx = FSDP.summon_full_params(self.model, writeback=False, recurse=False)\n            \n            with param_ctx:\n                with torch.autocast(device_type='cuda', dtype=torch.bfloat16):\n                    \n                    output = self.model.generate(\n                        input_ids=idx,\n                        pixel_values=pixel_values,\n                        attention_mask=attention_mask,\n                        do_sample=do_sample,\n                        max_new_tokens=response_length,\n                        # max_length=max_length,\n                        eos_token_id=eos_token_id,\n                        pad_token_id=pad_token_id,\n                        generation_config=generation_config,\n                        # 
renormalize_logits=True,\n                        output_scores=False,  # this is potentially very large\n                        return_dict_in_generate=True,\n                        use_cache=True)\n                    \n           \n            seq = output.sequences\n            sequence_length = prompt_length + response_length\n            delta_length = sequence_length - seq.shape[1]\n            \n            assert delta_length == 0\n            assert seq.shape[1] == sequence_length\n\n            prompt = seq[:, :prompt_length]  # (bs, prompt_length)\n            response = seq[:, prompt_length:]  # (bs, response_length)\n\n            response_length = response.size(1)\n            #delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)\n            #delta_position_id = delta_position_id.unsqueeze(0).repeat(batch_size, 1)\n            #response_position_ids = position_ids[:, -1:] + delta_position_id\n            #position_ids = torch.cat([position_ids, response_position_ids], dim=-1)\n\n            response_attention_mask = get_eos_mask(response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype)\n            attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)\n\n            # Extract predicted action tokens and translate into (normalized) continuous actions\n            predicted_action_token_ids = response.detach().cpu().numpy()\n            discretized_actions = self.model.vocab_size - predicted_action_token_ids\n            discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.model.bin_centers.shape[0] - 1)\n            normalized_actions = self.model.bin_centers[discretized_actions]\n\n            # Unnormalize actions\n            action_norm_stats = self.model.get_action_stats(self.config.embodied.unnorm_key)\n            mask = action_norm_stats.get(\"mask\", np.ones_like(action_norm_stats[\"q01\"], dtype=bool))\n            action_high, action_low = np.array(action_norm_stats[\"q99\"]), np.array(action_norm_stats[\"q01\"])\n            actions = np.where(\n                mask,\n                0.5 * (normalized_actions + 1) * (action_high - action_low) + action_low,\n                normalized_actions,\n            )\n            \n            actions = np.expand_dims(actions, axis=1)\n            \n            assert self.processor.tokenizer.pad_token_id is not None\n            assert prompt.ndim == 2\n            prompt = siirl_F.pad_sequence_to_length(prompt,max_seq_len=self.config.rollout.prompt_length,pad_token_id=self.processor.tokenizer.pad_token_id,left_pad=True)\n            assert seq.ndim == 2\n            seq = siirl_F.pad_sequence_to_length(seq,max_seq_len=self.config.rollout.prompt_length,pad_token_id=self.processor.tokenizer.pad_token_id,left_pad=True)\n            assert attention_mask.ndim == 2\n            attention_mask = siirl_F.pad_sequence_to_length(attention_mask,max_seq_len=self.config.rollout.prompt_length,pad_token_id=0,left_pad=True)\n            \n            batch ={\n                    'prompts': prompt,\n                    'responses': response,\n                    'input_ids': seq,\n                    'attention_mask': attention_mask,\n                    \"pixel_values\":pixel_values,\n                    \"action\":actions,\n                    #'position_ids': position_ids\n                }\n            \n            return batch\n\n    def _obs_to_input(self, obs):\n        from siirl.utils.embodied.libero_utils import get_libero_image, 
get_libero_wrist_image, quat2axisangle\n\n        if self.config.embodied.num_images_in_input > 1:\n            return {\n                \"full_image\": get_libero_image(obs, 224),\n                \"wrist_image\": get_libero_wrist_image(obs, 224),\n                \"state\": np.concatenate([\n                    obs[\"robot0_eef_pos\"],\n                    quat2axisangle(obs[\"robot0_eef_quat\"]),\n                    obs[\"robot0_gripper_qpos\"]\n                ])\n            }\n        else:\n            return {\n                \"full_image\": get_libero_image(obs, 224),\n                \"state\": np.concatenate([\n                    obs[\"robot0_eef_pos\"],\n                    quat2axisangle(obs[\"robot0_eef_quat\"]),\n                    obs[\"robot0_gripper_qpos\"]\n                ])\n            }\n"
  },
  {
    "path": "siirl/engine/rollout/hf_rollout.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nRollout with huggingface models.\nTODO: refactor this class. Currently, it will hang when using FSDP HybridShard. We should actually create a single\nGPU model. Then, get full state_dict and bind the state_dict to the single GPU model. Then, use the single GPU model\nto perform generation.\n\"\"\"\n\nimport contextlib\n\nimport torch\nimport torch.distributed\nfrom tensordict import TensorDict\nfrom torch import nn\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\nfrom transformers import GenerationConfig\n\nfrom siirl.utils.extras.device import get_device_name, get_torch_device\nfrom siirl.utils.model_utils.torch_functional import get_response_mask\n\nfrom siirl.engine.rollout.base import BaseRollout\n\n__all__ = [\"HFRollout\"]\n\n\nclass HFRollout(BaseRollout):\n    def __init__(self, module: nn.Module, config):\n        super().__init__()\n        self.config = config\n        self.module = module\n\n    def generate_sequences(self, prompts: TensorDict) -> TensorDict:\n        batch_size = prompts.batch.batch_size[0]\n        num_chunks = max(batch_size // self.config.get(\"micro_batch_size\", batch_size), 1)\n        batch_prompts = prompts.chunk(chunks=num_chunks)\n        output = [self._generate_minibatch(p) for p in batch_prompts]\n        output = TensorDict.concat(output)\n        return output\n\n    @torch.no_grad()\n    def _generate_minibatch(self, prompts: TensorDict) -> TensorDict:\n        # make sampling args can be overridden by inputs\n        do_sample = prompts.get(\"do_sample\", self.config.do_sample)\n        is_validate = prompts.get(\"validate\", False)\n\n        temperature = prompts.get(\"temperature\", self.config.temperature)\n        response_length = prompts.get(\"response_length\", self.config.response_length)\n        top_p = prompts.get(\"top_p\", self.config.get(\"top_p\", 1.0))\n        top_k = max(0, prompts.get(\"top_k\", self.config.get(\"top_k\", 0)))  # to be compatible with vllm\n\n        if not do_sample:\n            # do_sample==False -> greedy decoding\n            kwargs = {\n                \"do_sample\": False,\n                \"num_beams\": 1,\n            }\n        elif is_validate:\n            # do validate and do sample -> use val_kwargs\n            kwargs = {\n                \"do_sample\": True,\n                \"num_beams\": 1,\n                \"top_k\": max(0, self.config.val_kwargs.top_k),  # to be compatible with vllm\n                \"top_p\": self.config.val_kwargs.top_p,\n                \"temperature\": self.config.val_kwargs.temperature,\n                \"num_return_sequences\": 1,  # if validate, already repeat in ray_trainer\n            }\n        else:\n            # do_sample -> use rollout config\n            kwargs = {\n                \"do_sample\": True,\n                \"num_beams\": 1,\n                \"top_p\": top_p,\n                \"top_k\": 
top_k,\n                \"temperature\": temperature,\n                \"num_return_sequences\": 1, # already repeat in ray_trainer\n            }\n\n        # make config according to generate mode\n        generation_config = GenerationConfig(**kwargs)\n\n        idx = prompts[\"input_ids\"]  # (bs, prompt_length)\n        prompt_length = idx.size(1)\n        attention_mask = prompts[\"attention_mask\"]  # left-padded attention_mask\n        position_ids = prompts[\"position_ids\"]\n\n        # used to construct attention_mask\n        eos_token_id = prompts[\"eos_token_id\"]\n        pad_token_id = prompts[\"pad_token_id\"]\n\n        self.module.eval()\n        param_ctx = contextlib.nullcontext()\n\n        if isinstance(self.module, FSDP):\n            # recurse need to set to False according to https://github.com/pytorch/pytorch/issues/100069\n            param_ctx = FSDP.summon_full_params(self.module, writeback=False, recurse=False)\n        with param_ctx, torch.autocast(device_type=get_device_name(), dtype=torch.bfloat16):\n            output = self.module.generate(\n                input_ids=idx,\n                attention_mask=attention_mask,\n                position_ids=position_ids,\n                do_sample=do_sample,\n                max_new_tokens=response_length,\n                eos_token_id=eos_token_id,\n                pad_token_id=pad_token_id,\n                generation_config=generation_config,\n                output_scores=False,  # this is potentially very large\n                return_dict_in_generate=True,\n                use_cache=True,\n            )\n\n        # TODO: filter out the seq with no answers like ds-chat\n        seq = output.sequences\n        generated_batch_size = seq.size(0)  # bs * num_return_sequences\n\n        # huggingface generate will stop generating when all the batch reaches [EOS].\n        # We have to pad to response_length\n        sequence_length = prompt_length + self.config.response_length\n        delta_length = sequence_length - seq.shape[1]\n\n        if delta_length > 0:\n            delta_tokens = torch.ones(size=(generated_batch_size, delta_length), device=seq.device, dtype=seq.dtype)\n            delta_tokens = pad_token_id * delta_tokens\n            seq = torch.cat((seq, delta_tokens), dim=1)\n        assert seq.shape[1] == sequence_length\n\n        # make necessary repetitions if num_return_sequences > 1\n        num_return_sequences = kwargs.get(\"num_return_sequences\", 1)\n        if num_return_sequences > 1:\n            position_ids = position_ids.repeat_interleave(num_return_sequences, dim=0)\n            attention_mask = attention_mask.repeat_interleave(num_return_sequences, dim=0)\n\n        prompt = seq[:, :prompt_length]  # (generated_batch_size, prompt_length)\n        response = seq[:, prompt_length:]  # (generated_batch_size, response_length)\n\n        response_length = response.size(1)\n        delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)\n        delta_position_id = delta_position_id.unsqueeze(0).repeat(generated_batch_size, 1)\n\n        response_position_ids = position_ids[:, -1:] + delta_position_id\n        position_ids = torch.cat([position_ids, response_position_ids], dim=-1)\n\n        response_attention_mask = get_response_mask(response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype)\n        attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)\n\n        batch = TensorDict(\n            {\n               
\"prompts\": prompt,\n                \"responses\": response,\n                \"input_ids\": seq,\n                \"attention_mask\": attention_mask,\n                \"position_ids\": position_ids,\n            },\n            batch_size=generated_batch_size,\n        )\n\n        # empty cache before compute old_log_prob\n        get_torch_device().empty_cache()\n\n        self.module.train()\n        return batch\n"
  },
  {
    "path": "siirl/engine/rollout/schemas.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport difflib\nimport logging\nimport os\nfrom enum import Enum\nfrom typing import Any, Dict, List, Optional, Union\n\nimport torch\nfrom pydantic import BaseModel, ConfigDict, model_validator\nfrom transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin\n\nfrom siirl.execution.rollout_flow.multiturn.tools.schemas import OpenAIFunctionToolCall, OpenAIFunctionToolSchema\nfrom siirl.utils.model_utils.model import compute_position_id_with_mask\nfrom loguru import logger\n\n\nBASE_CHAT_HISTORY = [\n    {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n    {\"role\": \"user\", \"content\": \"I am a user.\"},\n]\n\n\nclass FinishReasonTypeEnum(str, Enum):\n    \"\"\"The enum for finish reason type.\"\"\"\n\n    LENGTH = \"length\"\n    STOP = \"stop\"\n    TOOL_CALL = \"tool_calls\"\n\n    @classmethod\n    def from_str(cls, value: str) -> \"FinishReasonTypeEnum\":\n        if value == \"stop\":\n            return cls.STOP\n        elif value == \"length\":\n            return cls.LENGTH\n        elif value == \"tool_calls\":\n            return cls.TOOL_CALL\n        else:\n            raise ValueError(f\"Unsupported finish reason type: {value}\")\n\n\nclass Message(BaseModel):\n    role: str\n    content: str | Dict[str, Any] | List[Dict[str, Any]]\n    tool_calls: Optional[List[OpenAIFunctionToolCall]] = None\n\n\nclass AsyncRolloutRequestStateEnum(str, Enum):\n    \"\"\"The enum for async rollout request state.\"\"\"\n\n    PENDING = \"pending\"\n    RUNNING = \"running\"\n    COMPLETED = \"completed\"\n    FAILED = \"failed\"\n    TOOL_CALLING = \"tool_calling\"\n    INTERACTING = \"interacting\"\n\n\nclass TokenizationSanityCheckModeEnum(str, Enum):\n    \"\"\"The enum for tokenization sanity check mode.\"\"\"\n\n    DISABLE = \"disable\"\n    STRICT = \"strict\"\n    IGNORE_STRIPPABLE = \"ignore_strippable\"\n\n\nclass AsyncRolloutRequest(BaseModel):\n    \"\"\"The data model for async rollout.\"\"\"\n\n    model_config = ConfigDict(arbitrary_types_allowed=True)\n\n    batch_data_id: int = 0\n    rollout_offset: int = 0\n    request_id: str\n    state: AsyncRolloutRequestStateEnum\n    messages: List[Message]\n    multi_modal_keys: Optional[List[str]] = None\n    multi_modal_data: Optional[Dict[str, Any]] = None\n    multi_modal_inputs: Optional[Dict[str, torch.Tensor]] = None\n    tool_schemas: Optional[List[OpenAIFunctionToolSchema]] = None\n    tools_kwargs: Dict[str, Any] = {}\n    interaction_kwargs: Dict[str, Any] = {}\n    input_ids: Optional[torch.Tensor] = None\n    prompt_ids: Optional[torch.Tensor] = None\n    response_ids: Optional[torch.Tensor] = None\n    attention_mask: Optional[torch.Tensor] = None\n    prompt_attention_mask: Optional[torch.Tensor] = None\n    response_attention_mask: Optional[torch.Tensor] = None\n    position_ids: Optional[torch.Tensor] = None\n    
prompt_position_ids: Optional[torch.Tensor] = None\n    response_position_ids: Optional[torch.Tensor] = None\n    loss_mask: Optional[torch.Tensor] = None\n    prompt_loss_mask: Optional[torch.Tensor] = None\n    response_loss_mask: Optional[torch.Tensor] = None\n    reward_scores: Dict[str, float]\n    max_prompt_len: int\n    max_response_len: int = 8192\n    max_model_len: int = 32768\n    metrics: Dict[str, List[Any]] = {}\n\n    use_inference_chat_template: bool\n    tokenization_sanity_check_mode: TokenizationSanityCheckModeEnum\n    generation_prompt_ids: Optional[torch.Tensor] = None\n    base_conv_wo_gen_prompt_end_pos: int\n    base_conv_with_gen_prompt_end_pos: int\n\n    @model_validator(mode=\"before\")\n    @classmethod\n    def initialize_request(cls, values):\n        if not (messages := values.get(\"messages\")):\n            raise ValueError(\"messages is required for AsyncRolloutRequest initialization\")\n        if not (max_prompt_len := values.get(\"max_prompt_len\")):\n            raise ValueError(\"max_prompt_len is required for AsyncRolloutRequest initialization\")\n        if not (processing_class := values.pop(\"processing_class\", None)):\n            raise ValueError(\"processing_class is required for AsyncRolloutRequest initialization\")\n\n        values[\"messages\"] = [Message.model_validate(msg) for msg in messages]\n\n        # If there is no multi_modal_keys, we assume the multi-modal data is image and video.\n        if not values.get(\"multi_modal_keys\"):\n            values[\"multi_modal_keys\"] = [\"image\", \"video\"]\n        if not values.get(\"multi_modal_data\"):\n            values[\"multi_modal_data\"] = {key: [] for key in values[\"multi_modal_keys\"]}\n        else:\n            # check if all multi_modal_keys are in multi_modal_data\n            for key in values[\"multi_modal_keys\"]:\n                if key not in values[\"multi_modal_data\"]:\n                    values[\"multi_modal_data\"][key] = []\n        if not values.get(\"multi_modal_inputs\"):\n            values[\"multi_modal_inputs\"] = {}\n\n        tools = (\n            [tool.model_dump() for tool in tool_schemas] if (tool_schemas := values.get(\"tool_schemas\", [])) else None\n        )\n\n        multi_modal_data = values[\"multi_modal_data\"]\n        tokens_without_prompt = cls._handle_apply_chat_template(\n            processing_class,\n            messages,\n            multi_modal_data=multi_modal_data,\n            tools=tools,\n            add_generation_prompt=False,\n            tokenize=True,\n        )\n        if (\n            values.get(\"input_ids\") is None\n            or values.get(\"attention_mask\") is None\n            or values.get(\"position_ids\") is None\n        ):\n            tokenization_dict_with_prompt = cls._handle_apply_chat_template(\n                processing_class,\n                messages,\n                multi_modal_data=multi_modal_data,\n                tools=tools,\n                add_generation_prompt=True,\n                tokenize=True,\n                return_dict=True,\n            )\n\n            values[\"input_ids\"], values[\"attention_mask\"] = (\n                tokenization_dict_with_prompt[\"input_ids\"],\n                tokenization_dict_with_prompt[\"attention_mask\"],\n            )\n            if values[\"input_ids\"].shape[-1] > max_prompt_len:\n                # Only log the warning to avoid truncating in the middle of generation prompt. 
Consider raising an\n                # error for this case in the future.\n                logger.warning(\n                    f\"Prompt {values['batch_data_id']} has length {values['input_ids'].shape[-1]} \"\n                    f\"which is greater than max_prompt_len {max_prompt_len} after applied chat template with tools.\"\n                )\n\n            # Process multi_modal_inputs\n            multi_modal_inputs = tokenization_dict_with_prompt.copy()\n            multi_modal_inputs.pop(\"input_ids\", None)\n            multi_modal_inputs.pop(\"attention_mask\", None)\n            values[\"multi_modal_inputs\"] = multi_modal_inputs\n\n            values[\"position_ids\"] = values[\"prompt_position_ids\"] = cls._get_position_ids(\n                processing_class, values[\"input_ids\"], values[\"attention_mask\"], multi_modal_inputs\n            )\n\n        values[\"prompt_ids\"], values[\"prompt_attention_mask\"] = values[\"input_ids\"], values[\"attention_mask\"]\n        values[\"loss_mask\"] = values[\"prompt_loss_mask\"] = torch.zeros_like(values[\"input_ids\"], dtype=torch.bool)\n        values[\"generation_prompt_ids\"] = values[\"input_ids\"][..., tokens_without_prompt.shape[-1] :]\n        values[\"base_conv_wo_gen_prompt_end_pos\"] = cls._handle_apply_chat_template(\n            processing_class,\n            BASE_CHAT_HISTORY,\n            multi_modal_data=multi_modal_data,\n            tools=tools,\n            add_generation_prompt=False,\n            tokenize=True,\n        ).shape[-1]\n\n        values[\"base_conv_with_gen_prompt_end_pos\"] = cls._handle_apply_chat_template(\n            processing_class,\n            BASE_CHAT_HISTORY,\n            multi_modal_data=multi_modal_data,\n            tools=tools,\n            add_generation_prompt=True,\n            tokenize=True,\n        ).shape[-1]\n\n        return values\n\n    @staticmethod\n    def _handle_apply_chat_template(\n        processing_class: Union[PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin],\n        messages: List[Message],\n        multi_modal_data: Dict[str, Any],\n        tools: Optional[List[OpenAIFunctionToolSchema]] = None,\n        add_generation_prompt: bool = False,\n        tokenize: bool = False,\n        return_dict: bool = False,\n    ):\n        raw_prompt = processing_class.apply_chat_template(\n            messages, tools=tools, add_generation_prompt=add_generation_prompt, tokenize=False\n        )\n        if not tokenize:\n            return raw_prompt\n\n        if isinstance(processing_class, PreTrainedTokenizer) or isinstance(processing_class, PreTrainedTokenizerFast):\n            if any(len(values) > 0 for values in multi_modal_data.values()):\n                logger.warning(\n                    \"There is multi_modal_data but you are not using a processor. 
Multi-modal data will be ignored.\"\n                )\n            model_inputs = processing_class(text=[raw_prompt], return_tensors=\"pt\")\n        elif isinstance(processing_class, ProcessorMixin):\n            # When we update multi_model_keys, we also need to update this logic\n            images = images if len(images := multi_modal_data.get(\"image\", [])) > 0 else None\n            videos = videos if len(videos := multi_modal_data.get(\"video\", [])) > 0 else None\n            model_inputs = processing_class(text=[raw_prompt], images=images, videos=videos, return_tensors=\"pt\")\n        else:\n            raise ValueError(f\"Unsupported processing class type: {type(processing_class)}\")\n\n        model_inputs = dict(model_inputs)\n        if return_dict:\n            return model_inputs\n        else:\n            return model_inputs[\"input_ids\"]\n\n    @staticmethod\n    def _get_position_ids(\n        processing_class: Union[PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin],\n        input_ids: torch.Tensor,\n        attention_mask: torch.Tensor,\n        multi_modal_inputs: Optional[Dict[str, torch.Tensor]] = None,\n    ) -> torch.Tensor:\n        # special case for qwen2vl\n        is_qwen2vl = (\n            hasattr(processing_class, \"image_processor\")\n            and \"Qwen2VLImageProcessor\" in processing_class.image_processor.__class__.__name__\n        )\n        if is_qwen2vl:\n            from siirl.models.transformers.qwen2_vl import get_rope_index\n\n            image_grid_thw = video_grid_thw = second_per_grid_ts = None\n            if multi_modal_inputs:\n                image_grid_thw = multi_modal_inputs.get(\"image_grid_thw\")\n                video_grid_thw = multi_modal_inputs.get(\"video_grid_thw\")\n                second_per_grid_ts = multi_modal_inputs.get(\"second_per_grid_ts\")\n\n            assert input_ids.dim() == 2 and input_ids.shape[0] == 1, (\n                f\"input_ids should be 2D with batch size 1, but got shape {input_ids.shape}\"\n            )\n            assert attention_mask.dim() == 2 and attention_mask.shape[0] == 1, (\n                f\"attention_mask should be 2D with batch size 1, but got shape {attention_mask.shape}\"\n            )\n            new_position_ids = get_rope_index(\n                processing_class,\n                input_ids=input_ids.squeeze(0),\n                image_grid_thw=image_grid_thw,\n                video_grid_thw=video_grid_thw,\n                second_per_grid_ts=second_per_grid_ts,\n                attention_mask=attention_mask.squeeze(0),\n            )\n            return new_position_ids  # (3, seq_len)\n        else:\n            return compute_position_id_with_mask(attention_mask)  # (1, seq_len)\n\n    def _update_input_ids(\n        self,\n        processing_class: Union[PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin],\n        new_input_ids: torch.Tensor,\n        attention_mask: bool,\n        loss_mask: bool,\n        new_multi_modal_inputs: Optional[Dict[str, torch.Tensor]] = None,\n    ) -> None:\n        \"\"\"\n        Update the input_ids, attention_mask, position_ids, and loss_mask of the request in additive manner.\n        \"\"\"\n        self.input_ids = torch.cat([self.input_ids, new_input_ids], dim=-1)\n        attention_mask = torch.ones_like(new_input_ids) * int(attention_mask)\n        self.attention_mask = torch.cat([self.attention_mask, attention_mask], dim=-1)\n        loss_mask = torch.ones_like(new_input_ids) * int(loss_mask)\n        
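# NOTE: the newly appended tokens share a single loss-mask value. Callers pass loss_mask=True only for\n        # assistant-generated tokens (see add_assistant_message); user, tool-response, and generation-prompt\n        # tokens are appended with loss_mask=False, so only model outputs contribute to the training loss.\n        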
self.loss_mask = torch.cat([self.loss_mask, loss_mask], dim=-1)\n\n        if new_multi_modal_inputs:\n            self._update_multi_modal_inputs(new_multi_modal_inputs)\n\n        new_position_ids = self._get_position_ids(\n            processing_class, new_input_ids, attention_mask, new_multi_modal_inputs\n        )\n\n        last_pos = self.position_ids[..., -1:]\n        new_position_ids = new_position_ids + (last_pos + 1)\n\n        self.position_ids = torch.cat([self.position_ids, new_position_ids], dim=-1)\n\n        assert (\n            self.input_ids.shape[-1]\n            == self.attention_mask.shape[-1]\n            == self.position_ids.shape[-1]\n            == self.loss_mask.shape[-1]\n        ), f\"\"\"Request {self.request_id} has different length of {self.input_ids.shape[-1]=}, \n            {self.attention_mask.shape[-1]=}, {self.position_ids.shape[-1]=}, {self.loss_mask.shape[-1]=}\"\"\"\n\n    def _update_multi_modal_inputs(self, new_multi_modal_inputs: Dict[str, torch.Tensor]) -> None:\n        \"\"\"\n        Update the multi_modal_inputs of the request in additive manner.\n        \"\"\"\n        for key in new_multi_modal_inputs:\n            input_tensor = new_multi_modal_inputs[key]\n            self.multi_modal_inputs[key] = (\n                torch.cat([self.multi_modal_inputs[key], input_tensor], dim=0)\n                if key in self.multi_modal_inputs\n                else input_tensor\n            )\n\n    def get_generation_prompt_ids(\n        self, processing_class: Union[PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin]\n    ) -> List[int]:\n        \"\"\"\n        Get the generation prompt ids for rollout engine.\n\n        Because rollout engine(SGLang) requires the ids to be a list, we need to convert the tensor to a list.\n        \"\"\"\n        generation_prompt_ids = (\n            None\n            if self.input_ids[..., -self.generation_prompt_ids.shape[-1] :].eq(self.generation_prompt_ids).all()\n            else self.generation_prompt_ids\n        )\n        if generation_prompt_ids is not None:\n            self._update_input_ids(processing_class, generation_prompt_ids, attention_mask=True, loss_mask=False)\n\n        if self.use_inference_chat_template:\n            messages = [msg.model_dump() for msg in self.messages]\n            tools = [tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None\n            generation_prompt_ids = self._handle_apply_chat_template(\n                processing_class,\n                messages,\n                multi_modal_data=self.multi_modal_data,\n                tools=tools,\n                add_generation_prompt=True,\n                tokenize=True,\n            )\n            return generation_prompt_ids.squeeze(0).tolist()\n        else:\n            return self.input_ids.squeeze(0).tolist()\n\n    def add_user_message(\n        self,\n        processing_class: Union[PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin],\n        content: str,\n    ) -> None:\n        self.messages.append(Message(role=\"user\", content=content))\n        messages = [*BASE_CHAT_HISTORY, self.messages[-1]]\n        tools = [tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None\n\n        # We don't need to pass multi_modal_data here because we don't have any multi-modal data from Engine\n        # Inference, it is pure text.\n        content_ids = self._handle_apply_chat_template(\n            processing_class, messages, multi_modal_data={}, 
tools=tools, add_generation_prompt=False, tokenize=True\n        )[..., self.base_conv_wo_gen_prompt_end_pos :]\n        self._update_input_ids(processing_class, content_ids, attention_mask=True, loss_mask=False)\n\n    def add_assistant_message(\n        self,\n        processing_class: Union[PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin],\n        content: str,\n        tool_calls: Optional[List[OpenAIFunctionToolCall]] = None,\n    ) -> None:\n        self.messages.append(Message(role=\"assistant\", content=content, tool_calls=tool_calls))\n\n        messages = [*BASE_CHAT_HISTORY, self.messages[-1]]\n        tools = [tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None\n\n        # We don't need to pass multi_modal_data here because we don't have any multi-modal data from Engine\n        # Inference, it is pure text.\n        content_ids = self._handle_apply_chat_template(\n            processing_class, messages, multi_modal_data={}, tools=tools, add_generation_prompt=False, tokenize=True\n        )[..., self.base_conv_with_gen_prompt_end_pos :]\n        self._update_input_ids(processing_class, content_ids, attention_mask=True, loss_mask=True)\n\n    def add_tool_response_messages(\n        self,\n        processing_class: Union[PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin],\n        contents: list[str | Dict[str, Any]],\n    ) -> None:\n        if not contents:\n            return\n        # We also handle the case when tool returns image\n        # We require the processing of the image and video to be done at tool.execute() level\n        delta_multi_modal_data = {key: [] for key in self.multi_modal_keys}\n        for content in contents:\n            if isinstance(content, dict):\n                content_list = []\n                # When we update multi_model_keys, we also need to update this logic\n                if \"image\" in content:\n                    if not isinstance(content[\"image\"], list):\n                        raise ValueError(\n                            f\"Image must be a list, but got {type(content['image'])}. Please check the tool.execute(). \"\n                            f\"For single images, wrap in a list: [image]. \"\n                            f\"Example: {{'image': [img1]}} or {{'image': [img1, img2, ...]}}.\"\n                        )\n\n                    content_list.extend([{\"type\": \"image\"} for _ in content[\"image\"]])\n                    delta_multi_modal_data[\"image\"].extend(content[\"image\"])\n                if \"video\" in content:\n                    if not isinstance(content[\"video\"], list):\n                        raise ValueError(\n                            f\"Video must be a list, but got {type(content['video'])}. Please check the tool.execute(). \"\n                            f\"For single videos, wrap in a list: [video]. 
\"\n                            f\"Example: {{'video': [video1]}} or {{'video': [video1, video2, ...]}}.\"\n                        )\n\n                    content_list.extend([{\"type\": \"video\"} for _ in content[\"video\"]])\n                    delta_multi_modal_data[\"video\"].extend(content[\"video\"])\n                if \"text\" in content:\n                    content_list.append({\"type\": \"text\", \"text\": content[\"text\"]})\n                for key in content:\n                    if key not in [\"image\", \"video\", \"text\"]:\n                        logger.warning(\n                            f\"Tool response message contains unexpected key: {key} \"\n                            f\"while we only support `image`, `video`, and `text`.\"\n                        )\n                self.messages.append(Message(role=\"tool\", content=content_list))\n            else:\n                self.messages.append(Message(role=\"tool\", content=content))\n\n        messages = [*BASE_CHAT_HISTORY, *self.messages[-len(contents) :]]\n        tools = [tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None\n\n        for key in self.multi_modal_keys:\n            if len(delta_multi_modal_data[key]) > 0:\n                self.multi_modal_data[key].extend(delta_multi_modal_data[key])\n\n        # We just passed the new multi-modal data to the chat template to update the input_ids.\n        content_info = self._handle_apply_chat_template(\n            processing_class,\n            messages,\n            multi_modal_data=delta_multi_modal_data,\n            tools=tools,\n            add_generation_prompt=False,\n            tokenize=True,\n            return_dict=True,\n        )\n        content_ids = content_info[\"input_ids\"][..., self.base_conv_wo_gen_prompt_end_pos :]\n\n        # process multi_modal_inputs\n        multi_modal_inputs = content_info.copy()\n        multi_modal_inputs.pop(\"input_ids\", None)\n        multi_modal_inputs.pop(\"attention_mask\", None)\n        self._update_input_ids(\n            processing_class,\n            content_ids,\n            attention_mask=True,\n            loss_mask=False,\n            new_multi_modal_inputs=multi_modal_inputs,\n        )\n\n    def update_metrics(self, metrics: Any, tool_id: str) -> None:\n        \"\"\"\n        metrics: should be a dict of tools_name -> Any\n        \"\"\"\n        if self.metrics.get(tool_id) is None:\n            self.metrics[tool_id] = []\n        self.metrics[tool_id].append(metrics)\n\n    def _get_prompt_diffs(\n        self,\n        processing_class: Union[PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin],\n        full_prompt_ids: torch.Tensor,\n        current_prompt_ids: torch.Tensor,\n        diff_surrounding_chars: int = 10,\n    ) -> List[Dict[str, Any]]:\n        \"\"\"Get differences between full prompt and current prompt with surrounding context.\n\n        This function helps debug tokenization mismatches by showing the differences between\n        full prompt and current prompt with surrounding context. 
Instead of just showing\n        the exact diff, it includes additional tokens before and after to help locate\n        the issue in the chat template.\n\n        For example, if the actual diff is a newline change from \"\\n\\n\" to \"\\n\", with\n        diff_surrounding_chars the output might look like:\n\n        full_prompt_chunk:    \"<|im_start|>assistant\\n\\nI think...\"\n        current_prompt_chunk: \"<|im_start|>assistant\\nI think...\"\n\n        This context makes it much easier to identify where in the chat template the\n        mismatch occurs.\n\n        Args:\n            processing_class: The processing class to use for decoding the token IDs\n            full_prompt_ids: Token IDs from applying chat template to all messages at once\n            current_prompt_ids: Token IDs from incremental chat template application\n            diff_surrounding_chars: Number of surrounding characters to include for context (default: 10)\n\n        Returns:\n            List of dicts containing the differing chunks with context and their indices\n        \"\"\"\n        full_prompt_ids = full_prompt_ids.squeeze(0)\n        current_prompt_ids = current_prompt_ids.squeeze(0)\n        full_prompt = processing_class.decode(full_prompt_ids, skip_special_tokens=False)\n        current_prompt = processing_class.decode(current_prompt_ids, skip_special_tokens=False)\n        s = difflib.SequenceMatcher(None, full_prompt, current_prompt, autojunk=False)\n        diffs = []\n        for tag, i1, i2, j1, j2 in s.get_opcodes():\n            if tag == \"equal\":\n                continue\n\n            # Get the surrounding context for better readability\n            start_i = max(0, i1 - diff_surrounding_chars)\n            end_i = min(len(full_prompt), i2 + diff_surrounding_chars)\n            start_j = max(0, j1 - diff_surrounding_chars)\n            end_j = min(len(current_prompt), j2 + diff_surrounding_chars)\n\n            diffs.append(\n                {\n                    \"full_prompt_chunk\": full_prompt[start_i:end_i],\n                    \"current_prompt_chunk\": current_prompt[start_j:end_j],\n                    \"indices\": (start_i, end_i, start_j, end_j),\n                }\n            )\n        return diffs\n\n    def finalize(\n        self,\n        processing_class: Union[PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin],\n        reward_scores: Dict[str, List[float]],\n        finish_reason_type: FinishReasonTypeEnum = FinishReasonTypeEnum.STOP,\n    ) -> None:\n        self.state = AsyncRolloutRequestStateEnum.COMPLETED\n        self.reward_scores = reward_scores\n\n        # In case we failed to generate the assistant message and the generation prompt ids were already added to\n        # input_ids, remove them from the end of input_ids\n        if self.input_ids[..., -self.generation_prompt_ids.shape[-1] :].eq(self.generation_prompt_ids).all():\n            self.input_ids = self.input_ids[..., : -self.generation_prompt_ids.shape[-1]]\n            self.attention_mask = self.attention_mask[..., : -self.generation_prompt_ids.shape[-1]]\n            self.position_ids = self.position_ids[..., : -self.generation_prompt_ids.shape[-1]]\n            self.loss_mask = self.loss_mask[..., : -self.generation_prompt_ids.shape[-1]]\n\n        self.response_ids = self.input_ids[..., self.prompt_ids.shape[-1] :]\n\n        if self.tokenization_sanity_check_mode != TokenizationSanityCheckModeEnum.DISABLE:\n            # When there is a diff, we log the diffs with 
diff_surrounding_chars context\n            diff_surrounding_chars = 10\n\n            messages = [msg.model_dump() for msg in self.messages]\n            tools = [tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None\n            full_prompt_info = self._handle_apply_chat_template(\n                processing_class,\n                messages,\n                multi_modal_data=self.multi_modal_data,\n                tools=tools,\n                add_generation_prompt=False,\n                tokenize=True,\n                return_dict=True,\n            )\n            full_prompt_ids = full_prompt_info[\"input_ids\"]\n\n            # We must use dict(full_prompt_info) to convert BatchFeature values to a new dict\n            # because np.array() only keeps the keys for BatchFeature.\n            full_prompt_multi_modal_inputs = full_prompt_info.copy()\n            full_prompt_multi_modal_inputs.pop(\"input_ids\", None)\n            full_prompt_multi_modal_inputs.pop(\"attention_mask\", None)\n\n            for multi_modal_inputs_key in self.multi_modal_inputs:\n                if multi_modal_inputs_key in full_prompt_multi_modal_inputs:\n                    if (\n                        not self.multi_modal_inputs[multi_modal_inputs_key]\n                        .eq(full_prompt_multi_modal_inputs[multi_modal_inputs_key])\n                        .all()\n                    ):\n                        logger.warning(\n                            f\"Multi-modal data {multi_modal_inputs_key} is not consistent. \"\n                            f\"This may lead to unexpected behavior during training. \"\n                            f\"Please review your multi_modal_inputs logic.\"\n                        )\n                else:\n                    logger.warning(\n                        f\"Multi-modal inputs key {multi_modal_inputs_key} is not found in the multi_modal_inputs. \"\n                        f\"This may lead to unexpected behavior during training.\"\n                        f\"Please review your multi_modal_inputs logic.\"\n                    )\n\n            if diffs := self._get_prompt_diffs(\n                processing_class, full_prompt_ids, self.input_ids, diff_surrounding_chars=diff_surrounding_chars\n            ):\n                log_warning = False\n                if self.tokenization_sanity_check_mode == TokenizationSanityCheckModeEnum.STRICT:\n                    log_warning = True\n                elif self.tokenization_sanity_check_mode == TokenizationSanityCheckModeEnum.IGNORE_STRIPPABLE:\n                    non_strippable_diffs_exist = any(\n                        d[\"full_prompt_chunk\"].strip() or d[\"current_prompt_chunk\"].strip() for d in diffs\n                    )\n                    if non_strippable_diffs_exist:\n                        log_warning = True\n\n                if log_warning:\n                    mode_str = f\" ({self.tokenization_sanity_check_mode.value})\"\n                    logger.warning(\n                        f\"Inconsistent training and inference tokenization detected{mode_str}. This may lead to \"\n                        f\"unexpected behavior during training. Please review your chat template to determine if this \"\n                        f\"is intentional. 
For more information, refer to the multiturn README.md.\"\n                    )\n                    logger.warning(\n                        f\"Showing {diff_surrounding_chars} characters before and after the diffs for context and \"\n                        f\"better readability.\"\n                    )\n                    diff_details_list = []\n                    for d in diffs:\n                        i1, i2, j1, j2 = d[\"indices\"]\n                        diff_details_list.append(\n                            f\"idx {i1}:{i2} -> {j1}:{j2} | full_prompt_chunk: {repr(d['full_prompt_chunk'])} | \"\n                            f\"current_prompt_chunk: {repr(d['current_prompt_chunk'])}\"\n                        )\n                    diff_details = \"\\n\".join(diff_details_list)\n                    logger.warning(f\"Found differences:\\n{diff_details}\")\n\n        if finish_reason_type == FinishReasonTypeEnum.STOP:\n            pass\n        elif finish_reason_type == FinishReasonTypeEnum.LENGTH:\n            pass\n        else:\n            raise ValueError(f\"Unsupported finalize finish reason type: {finish_reason_type}\")\n        self.truncate_output_ids(processing_class)\n\n        assert (\n            self.input_ids.shape[-1]\n            == self.attention_mask.shape[-1]\n            == self.position_ids.shape[-1]\n            == self.loss_mask.shape[-1]\n        ), f\"\"\"Request {self.request_id} has different length of {self.input_ids.shape[-1]=}, \n            {self.attention_mask.shape[-1]=}, {self.position_ids.shape[-1]=}, {self.loss_mask.shape[-1]=}\"\"\"\n\n    def truncate_output_ids(\n        self, processing_class: Union[PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin]\n    ) -> None:\n        self.input_ids = self.input_ids[..., : self.max_model_len]\n        self.attention_mask = self.attention_mask[..., : self.max_model_len]\n        self.position_ids = self.position_ids[..., : self.max_model_len]\n        self.loss_mask = self.loss_mask[..., : self.max_model_len]\n        self.response_ids = self.input_ids[..., self.prompt_ids.shape[-1] :][..., : self.max_response_len]\n        self.response_attention_mask = self.attention_mask[..., self.prompt_attention_mask.shape[-1] :][\n            ..., : self.max_response_len\n        ]\n        self.response_position_ids = self.position_ids[..., self.prompt_position_ids.shape[-1] :][\n            ..., : self.max_response_len\n        ]\n        self.response_loss_mask = self.loss_mask[..., self.prompt_loss_mask.shape[-1] :][..., : self.max_response_len]\n"
  },
  {
    "path": "siirl/engine/rollout/sglang_rollout/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n\nfrom .sglang_rollout import SGLangRollout\n\n__all__ = [\"SGLangRollout\"]\n"
  },
  {
    "path": "siirl/engine/rollout/sglang_rollout/async_sglang_server.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport asyncio\nimport logging\nfrom typing import Any, Dict, List, Tuple\nimport pickle\nimport zmq\nimport torch\nimport ray\nfrom omegaconf import DictConfig\nfrom starlette.requests import Request\nfrom starlette.responses import JSONResponse\n\nfrom siirl.engine.rollout.async_server import AsyncServerBase\nfrom siirl.params.model_args import ActorRolloutRefArguments\nfrom siirl.engine.rollout.sglang_rollout import SGLangRollout\nlogger = logging.getLogger(__file__)\n\n\nclass AsyncSglangServer(AsyncServerBase):\n    def __init__(self, config: ActorRolloutRefArguments, spmd_engine: SGLangRollout, zmq_addresses:List):\n        super().__init__()\n        self.config = config.rollout\n        self.workers_zmq = []\n        self.master_worker_zmq = None\n        self.engine = spmd_engine\n        self.zmq_addresses = zmq_addresses\n        \n    async def init_engine(self):\n        self.context = zmq.Context()\n        for zmq_address in self.zmq_addresses:\n            socket = self.context.socket(zmq.REQ)\n            socket.connect(zmq_address)\n            self.workers_zmq.append(socket)\n        self.master_worker_zmq = self.workers_zmq[0]\n    async def chat_completion(self, raw_request: Request):\n        request = await raw_request.json()\n        message = pickle.dumps(('chat_completion', (), {'request':request}))\n        self.master_worker_zmq.send(message, zmq.DONTWAIT)\n        outputs = []\n        outputs.append(pickle.loads(self.master_worker_zmq.recv()))\n        return JSONResponse(outputs)\n\n\n    async def generate(self, prompt_ids: List[int], sampling_params: Dict[str, Any], request_id: str) -> List[int]:\n        return await self.engine.generate(prompt_ids, sampling_params, request_id)\n\n    async def wake_up(self):\n        if not self.config.free_cache_engine:\n            return\n        message = pickle.dumps(('wake_up', (), {}))\n        for socket in self.workers_zmq:\n            socket.send(message, zmq.DONTWAIT)\n        for socket in self.workers_zmq:\n            socket.recv()\n        return\n\n    async def sleep(self):\n        if not self.config.free_cache_engine:\n            return\n        message = pickle.dumps(('sleep', (), {}))\n        for socket in self.workers_zmq:\n            socket.send(message, zmq.DONTWAIT)\n        for socket in self.workers_zmq:\n            socket.recv()\n        return\n"
  },
  {
    "path": "siirl/engine/rollout/sglang_rollout/sglang_rollout.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom __future__ import annotations\n\nimport asyncio\nimport logging\nimport multiprocessing as mp\nimport os\nimport time\nfrom copy import deepcopy\nfrom json import JSONDecodeError\nfrom typing import Any, List, Optional, Tuple, Union\nfrom uuid import uuid4\nimport pickle\nimport socket\nimport threading\nimport ray\nimport zmq\nfrom filelock import FileLock\nimport numpy as np\nimport sglang.srt.entrypoints.engine\nimport torch\nimport torch.distributed as dist\nfrom omegaconf import DictConfig\nfrom sglang.srt.managers.tokenizer_manager import (\n    ReleaseMemoryOccupationReqInput,\n    ResumeMemoryOccupationReqInput,\n    UpdateWeightsFromTensorReqInput,\n)\nfrom sglang.srt.openai_api.protocol import Tool\nfrom sglang.srt.sampling.sampling_params import SamplingParams\nfrom sglang.srt.server_args import ServerArgs\nfrom sglang.srt.utils import (\n    MultiprocessingSerializer,\n    assert_pkg_version,\n    get_ip,\n    get_open_port,\n    is_cuda,\n    maybe_set_triton_cache_manager,\n    set_prometheus_multiproc_dir,\n    set_ulimit,\n)\nfrom tensordict import TensorDict\nfrom torch.distributed.device_mesh import DeviceMesh, init_device_mesh\nfrom torch.nn.utils.rnn import pad_sequence\nfrom transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin\n\nfrom siirl.execution.rollout_flow.multiturn.interactions.base import BaseInteraction\nfrom siirl.execution.rollout_flow.multiturn.interactions.utils.interaction_registry import initialize_interactions_from_config\nfrom siirl.third_party.sglang import parallel_state as sglang_ps\nfrom siirl.execution.rollout_flow.multiturn.tools.base_tool import BaseTool\nfrom siirl.execution.rollout_flow.multiturn.tools.schemas import OpenAIFunctionCallSchema, OpenAIFunctionParsedSchema, OpenAIFunctionToolCall\nfrom siirl.execution.rollout_flow.multiturn.tools.utils.tool_registry import initialize_tools_from_config\n\n\nfrom siirl.utils.extras.net_utils import is_ipv6\nfrom siirl.utils.debug import GPUMemoryLogger\nfrom siirl.utils.model_utils.torch_functional import get_response_mask, pad_sequence_to_length\nfrom siirl.engine.rollout.base import BaseRollout\n\n\nfrom siirl.engine.rollout.schemas import (\n    AsyncRolloutRequest,\n    AsyncRolloutRequestStateEnum,\n    FinishReasonTypeEnum,\n    Message,\n)\n\nfrom siirl.params import RolloutArguments\nfrom siirl.engine.rollout.sglang_rollout.utils import broadcast_pyobj\nfrom loguru import logger\n\ntry:\n    from sglang.srt.function_call.function_call_parser import FunctionCallParser\nexcept ImportError:\n    from sglang.srt.function_call_parser import FunctionCallParser\n\n\n# patch to avoid issue https://github.com/sgl-project/sglang/issues/6723\ndef _set_envs_and_config(server_args: ServerArgs):\n    # Set global environments\n    
os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\"\n    os.environ[\"NCCL_CUMEM_ENABLE\"] = \"0\"\n    os.environ[\"NCCL_NVLS_ENABLE\"] = str(int(server_args.enable_nccl_nvls))\n    os.environ[\"TORCH_NCCL_AVOID_RECORD_STREAMS\"] = \"1\"\n    os.environ[\"CUDA_DEVICE_MAX_CONNECTIONS\"] = \"4\"\n    os.environ[\"CUDA_MODULE_LOADING\"] = \"AUTO\"\n\n    # Set prometheus env vars\n    if server_args.enable_metrics:\n        set_prometheus_multiproc_dir()\n\n    # Set ulimit\n    set_ulimit()\n\n    # Fix triton bugs\n    if server_args.tp_size * server_args.dp_size > 1:\n        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.\n        maybe_set_triton_cache_manager()\n\n    # Check flashinfer version\n    if server_args.attention_backend == \"flashinfer\":\n        assert_pkg_version(\n            \"flashinfer_python\",\n            \"0.2.5\",\n            \"Please uninstall the old version and reinstall the latest version by following the instructions at https://docs.flashinfer.ai/installation.html.\",\n        )\n    if is_cuda():\n        assert_pkg_version(\n            \"sgl-kernel\",\n            \"0.1.1\",\n            \"Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`\",\n        )\n\n    # Set mp start method\n    mp.set_start_method(\"spawn\", force=True)\n\n\nsglang.srt.entrypoints.engine._set_envs_and_config = _set_envs_and_config\n\n\n# because chatCompletion is an async method, it makes the whole ray actor be an async actor\n# which can not call loop.run_until_complete. So we need to make the engine to be an async class\nclass AsyncEngine(sglang.srt.entrypoints.engine.Engine):\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n        # default to use dummy load format, which need to reload weights in first time\n        self._need_reload = True\n\n    async def release_memory_occupation(self, tags: Optional[list[str]] = None):\n        \"\"\"Release GPU occupation temporarily.\"\"\"\n        if tags is None:\n            obj = ReleaseMemoryOccupationReqInput()\n        else:\n            obj = ReleaseMemoryOccupationReqInput(tags=tags)\n        return await self.tokenizer_manager.release_memory_occupation(obj, None)\n\n    async def resume_memory_occupation(self, tags: Optional[list[str]] = None):\n        \"\"\"Resume GPU occupation.\"\"\"\n        # because __init__ is a sync method, it can not call the async release_memory_occupation\n        # have to move release_memory_occupation from __init__ to here\n        # For multi-stage awake, we run release weight and kv_cache when we resume weights for the first time.\n        if self._need_reload:\n            await self.release_memory_occupation()\n            self._need_reload = False\n\n        if tags is None:\n            obj = ResumeMemoryOccupationReqInput()\n        else:\n            obj = ResumeMemoryOccupationReqInput(tags=tags)\n        return await self.tokenizer_manager.resume_memory_occupation(obj, None)\n\n    async def update_weights_from_tensor(\n        self,\n        named_tensors: List[Tuple[str, torch.Tensor]],  # noqa: UP006\n        load_format: Optional[str] = None,\n        flush_cache: bool = True,\n    ):\n        \"\"\"Update weights from distributed source. 
If there are going to be more updates, set `flush_cache` to be false\n        to avoid duplicated cache cleaning operation.\"\"\"\n        obj = UpdateWeightsFromTensorReqInput(\n            serialized_named_tensors=[\n                MultiprocessingSerializer.serialize(named_tensors) for _ in range(self.server_args.tp_size)\n            ],\n            load_format=load_format,\n            flush_cache=flush_cache,\n        )\n        return await self.tokenizer_manager.update_weights_from_tensor(obj, None)\n\n    async def flush_cache(self):\n        return await self.tokenizer_manager.flush_cache()\n\n\n# NOTE(sgm): add for verl. We can optimize it by making\n#  the dataloader yield List[int] without padding.\ndef _pre_process_inputs(\n    pad_token_id,\n    prompt_token_ids: torch.Tensor,\n) -> torch.Tensor:\n    # remove the left padding in the prompt token_id\n    non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0]\n    return prompt_token_ids[non_pad_index:]\n\n\n# NOTE(linjunrong): adhoc\ndef _post_process_outputs(processing_class, output):\n    try:\n        # This is when processing_class is a processor\n        tokenizer = processing_class.tokenizer\n    except AttributeError:\n        try:\n            # This is when processing_class is a tokenizer\n            tokenizer = processing_class\n        except AttributeError as e:\n            raise ValueError(f\"Cannot get tokenizer from processing_class {processing_class}\") from e\n\n    def _map_each_response(resp):\n        output_token_logprobs = resp[\"meta_info\"][\"output_token_logprobs\"]\n        log_probs, output_token_ids = zip(*[(log_prob, token_ids) for log_prob, token_ids, _ in output_token_logprobs])\n        return torch.tensor(output_token_ids), torch.tensor(log_probs)\n\n    out_map = map(lambda x: _map_each_response(x), output)\n    batched_output_token_ids = []\n    batched_logprobs = []\n    for output_token_ids, log_probs in out_map:\n        batched_output_token_ids.append(output_token_ids)\n        batched_logprobs.append(log_probs)\n    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id\n    batched_output_token_ids = pad_sequence(batched_output_token_ids, batch_first=True, padding_value=pad_token_id)\n    if len(batched_logprobs) > 0:\n        batched_logprobs = pad_sequence(batched_logprobs, batch_first=True, padding_value=pad_token_id)\n    return batched_output_token_ids, batched_logprobs\n\n\ndef get_tool_call_parser_type(\n    processing_class: Union[PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin],\n) -> str:\n    items = FunctionCallParser.ToolCallParserEnum.items()\n    for parser_type, parser_cls in items:\n        parser = parser_cls()\n        try:\n            # This is when processing_class is a tokenizer\n            tokenizer_vocab = processing_class.get_vocab()\n        except AttributeError:\n            try:\n                # This is when processing_class is a processor\n                tokenizer_vocab = processing_class.tokenizer.get_vocab()\n            except AttributeError as e:\n                raise ValueError(f\"Cannot get vocab from processing_class {processing_class}\") from e\n\n        if parser.bot_token.strip() in tokenizer_vocab and (\n            parser.eot_token == \"\" or parser.eot_token.strip() in tokenizer_vocab\n        ):\n            return parser_type\n    else:\n        raise ValueError(f\"No tool call parser found for processing_class 
{processing_class}\")\n\n\nclass SGLangRollout(BaseRollout):\n    def __init__(\n        self,\n        actor_module: str,\n        config: RolloutArguments,\n        processing_class: Union[PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin],\n        model_hf_config,\n        port=None,\n        trust_remote_code: bool = False,\n        device_mesh: DeviceMesh | None = None,\n        **kwargs,\n    ):\n        \"\"\"Synchronized SGLang rollout engine.\n\n        Args:\n            actor_module: Huggingface model name or path to the model. The\n                model should be supported by SGLang.\n            config: A DictConfig object containing SGLang-specific operational\n                parameters and rollout settings.\n                Refer to https://docs.sglang.ai/backend/server_arguments.html\n            processing_class: The tokenizer or processor instance compatible with the actor_module.\n            model_hf_config: The Hugging Face model's configuration (e.g.,\n                `transformers.PretrainedConfig`). It provides architectural\n                details and hyperparameters like `max_position_embeddings`,\n                used by SGLang for correct model initialization. This is\n                the model's inherent design, not SGLang's runtime behavior.\n            port: Optional port for multi-node initialization when nnodes > 1.\n            trust_remote_code: Whether or not to allow for custom models\n                defined on the Hub in their own modeling files.\n            device_mesh: Optional `DeviceMesh` object for distributed setup.\n            **kwargs: Additional keyword arguments, primarily `train_tp` for\n                Megatron Backend integration to initialize hybrid engine\n                process groups.\n        \"\"\"\n        super().__init__()\n        self.config = config\n        self._device_mesh_cpu = device_mesh\n        os.environ.setdefault(\"SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK\", \"true\")\n\n        (\n            self._tool_schemas,\n            self._tool_map,\n            self._tool_call_parser_type,\n            self._sgl_tools,\n            self._function_call_parser,\n        ) = self._initialize_tools(config, processing_class)\n        self.interaction_map: dict[str, BaseInteraction] = self._initialize_interactions(config)\n        # If turn on `free_cache_engine`, SGLang engine's KV cache\n        # will be freed after each `generate_sequences` call.\n        logger.info(\n            f\"tool_schemas: {self._tool_schemas}, tool_map: {self._tool_map}, tool_call_parser_type: \"\n            f\"{self._tool_call_parser_type}, sgl_tools: {self._sgl_tools}, function_call_parser: \"\n            f\"{self._function_call_parser}\"\n        )\n\n        self._init_distributed_env(device_mesh_cpu=device_mesh, **kwargs)\n\n        self._verify_config(model_hf_config=model_hf_config)\n        # initialize the inference engine\n        self._init_inference_engine(trust_remote_code, actor_module, port)\n\n        self._init_sampling_params(**kwargs)\n\n        self.processing_class = processing_class\n\n        if self.config.mode == 'async':\n            self.address = self._init_zeromq()\n        try:\n            # This is when processing_class is a tokenizer\n            self.pad_token_id = self.processing_class.pad_token_id\n        except AttributeError:\n            try:\n                # This is when processing_class is a processor\n                self.pad_token_id = self.processing_class.tokenizer.pad_token_id\n       
     except AttributeError as e:\n                raise ValueError(f\"Cannot get pad_token_id from processing_class {self.processing_class}\") from e\n\n    def _init_distributed_env(self, device_mesh_cpu, **kwargs):\n        self._device_mesh_cpu = device_mesh_cpu\n        os.environ.setdefault(\"SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK\", \"true\")\n        self.tensor_parallel_size = self.config.tensor_model_parallel_size\n        assert self.tensor_parallel_size <= dist.get_world_size(), (\n            \"tensor parallel size should be less than or equal to the world size\"\n        )\n        self.train_tp = kwargs.get(\"train_tp\", None)\n        if self.train_tp is not None:\n            # deployed with megatron\n            os.environ[\"CUDA_TIMER_STREAM_KAFKA_ENABLE\"] = \"0\"\n            os.environ[\"MEGATRON_IMPORT_TIMERS\"] = \"0\"\n            train_tp = kwargs.get(\"train_tp\", None)\n            num_tp_per_train_tp = train_tp // self.tensor_parallel_size\n            sglang_ps.initialize_parallel_state(\n                tensor_model_parallel_size=self.tensor_parallel_size,\n                num_tp_per_train_tp=num_tp_per_train_tp,\n            )\n\n        tp_size = self.tensor_parallel_size\n        world_size = int(os.getenv(\"WORLD_SIZE\", \"-1\"))\n\n        # init device mesh\n        if self._device_mesh_cpu is None:\n            device_mesh_kwargs = dict(\n                mesh_shape=(world_size // tp_size, tp_size, 1),\n                mesh_dim_names=[\"dp\", \"tp\", \"pp\"],\n            )\n\n            self._device_mesh_cpu = init_device_mesh(\"cpu\", **device_mesh_kwargs)\n\n        self._rank = self._device_mesh_cpu.get_rank()\n        self._tp_rank = self._device_mesh_cpu[\"tp\"].get_local_rank()\n        self._tp_size = self._device_mesh_cpu[\"tp\"].size()\n        if self._rank == 0:\n            logger.info(f\"_init_distributed_env: :tp_world: {self._tp_size}, global_world: {world_size}\")\n        # get tp_rank of this process in this tp group\n        visible_devices = [None] * self._device_mesh_cpu.size(1)\n\n        torch.distributed.all_gather_object(\n            visible_devices, os.environ[\"CUDA_VISIBLE_DEVICES\"], self._device_mesh_cpu.get_group(\"tp\")\n        )\n        self.visible_devices_set = set(\",\".join(visible_devices).split(\",\"))\n        os.environ[\"CUDA_VISIBLE_DEVICES\"] = \",\".join(sorted(list(self.visible_devices_set)))\n\n    def _verify_config(self, model_hf_config):\n        if not self.config.max_model_len:\n            self.config.max_model_len = self.config.prompt_length + self.config.response_length\n        assert (\n            self.config.max_model_len >= self.config.prompt_length + self.config.response_length\n        ), f\"\"\"max_model_len should be greater than total sequence length (prompt_length + response_length): \n            {self.config.max_model_len} >= {self.config.prompt_length} + {self.config.response_length}\"\"\"\n        max_position_embeddings = None\n        if hasattr(model_hf_config, \"max_position_embeddings\"):\n            max_position_embeddings = model_hf_config.max_position_embeddings\n        elif hasattr(model_hf_config, \"llm_config\") and hasattr(model_hf_config.llm_config, \"max_position_embeddings\"):\n            max_position_embeddings = model_hf_config.llm_config.max_position_embeddings\n        elif hasattr(model_hf_config, \"text_config\") and hasattr(\n            model_hf_config.text_config, \"max_position_embeddings\"\n        ):\n            max_position_embeddings = 
model_hf_config.text_config.max_position_embeddings\n        if max_position_embeddings is None:\n            raise ValueError(\"max_position_embeddings not found in model_hf_config\")\n        rope_scaling_config = getattr(model_hf_config, \"rope_scaling\", None)\n        if not rope_scaling_config:\n            assert max_position_embeddings >= self.config.prompt_length + self.config.response_length, (\n                \"model context length should be greater than total sequence length\"\n            )\n        else:\n            # handle type where there's a length extend factor\n            # see https://qwen.readthedocs.io/en/latest/deployment/vllm.html#extended-context-support\n            # for using yarn as an example\n            rope_scaling_factor = rope_scaling_config.get(\"factor\", 1.0)\n\n            assert (\n                model_hf_config.max_position_embeddings * rope_scaling_factor\n                >= self.config.prompt_length + self.config.response_length\n            ), (\n                f\"model context length should be greater than total sequence length, \"\n                f\"got rope_scaling_factor={rope_scaling_factor} and \"\n                f\"max_position_embeddings={model_hf_config.max_position_embeddings}\"\n            )\n\n        # currently max_assistant_turns stand for max number of tool calls\n\n        if self.config.multi_turn.max_assistant_turns is None:\n            self.config.multi_turn.max_assistant_turns = self.config.max_model_len // 3\n        if self.config.multi_turn.max_user_turns is None:\n            self.config.multi_turn.max_user_turns = self.config.max_model_len // 3\n\n\n    def _init_inference_engine(self, trust_remote_code, actor_module, port):\n        # initialize the inference engine\n        nnodes = -(-self._tp_size // len(self.visible_devices_set))\n        if nnodes > 1:\n            ip = get_ip()\n            port = get_open_port() if port is None else port\n            [ip, port] = broadcast_pyobj(\n                [ip, port],\n                rank=self._rank,\n                dist_group=self._device_mesh_cpu.get_group(\"tp\"),\n                src=self._device_mesh_cpu[\"tp\"].mesh[0].item(),\n                force_cpu_device=False,\n            )\n            dist_init_addr = f\"[{ip}]:{port}\" if is_ipv6(ip) else f\"{ip}:{port}\"\n        else:\n            dist_init_addr = None\n\n        load_format = \"dummy\" if self.config.load_format.startswith(\"dummy\") else self.config.load_format\n        tp_size_per_node = self._tp_size // nnodes\n        node_rank = self._tp_rank // tp_size_per_node\n        first_rank_in_node = self._tp_rank % tp_size_per_node == 0\n\n        if first_rank_in_node:\n            rank = dist.get_rank()\n            os.environ[\"SGLANG_BLOCK_NONZERO_RANK_CHILDREN\"] = \"0\"\n            self.inference_engine = AsyncEngine(\n                model_path=actor_module,\n                dtype=self.config.dtype,\n                mem_fraction_static=self.config.gpu_memory_utilization,\n                enable_memory_saver=True,\n                base_gpu_id=0,\n                gpu_id_step=1,\n                tp_size=self._tp_size,\n                node_rank=node_rank,\n                load_format=load_format,\n                dist_init_addr=dist_init_addr,\n                nnodes=nnodes,\n                trust_remote_code=trust_remote_code,\n                # NOTE(linjunrong): add rank to prevent SGLang generate same port inside PortArgs.init_new\n                # when random.seed is being set during 
training\n                port=30000 + rank,\n                # NOTE(Chenyang): if you want to debug the SGLang engine output\n                # please set the following parameters\n                # Otherwise, it will make the engine run too slow\n                # log_level=\"INFO\",\n                # log_requests=True,\n                # log_requests_level=2,\n                # max_running_requests=1,\n                mm_attention_backend=\"fa3\",\n                attention_backend=\"fa3\",\n                # In async mode, we want token in token out.\n                skip_tokenizer_init=self.config.mode == \"async\",\n            )\n        else:\n            self.inference_engine = None\n\n        self.sharding_manager = None\n        self.is_sleep = True\n\n    def _init_sampling_params(self, **kwargs):\n        kwargs = dict(\n            n=1,\n            max_new_tokens=self.config.response_length,\n            presence_penalty=0.0,\n            frequency_penalty=0.0,\n            repetition_penalty=1.0,\n        )\n        # support adding any sampling params from the config file\n        dictConfig = self.config.to_dict()\n        for k in dictConfig.keys():\n            if hasattr(SamplingParams(), str(k)) or \"stop\" in str(k):\n                kwargs[k] = dictConfig.get(k)\n        kwargs[\"n\"] = 1  # always 1: ray_trainer already repeats each prompt rollout.n times\n        self.sampling_params = kwargs\n\n    def _initialize_tools(self, config, processing_class):\n        \"\"\"Initialize tools from configuration.\n        Args:\n            config: Configuration object containing tool-related settings,\n                    specifically `config.multi_turn.tool_config_path`.\n            processing_class: The tokenizer/processor instance used for parsing tool calls from\n                       the model's generated text.\n        Returns:\n            tuple: A tuple containing:\n                - tool_schemas (list[dict]): OpenAI-formatted JSON schemas\n                  defining each tool's capabilities.\n                - tool_map (dict[str, BaseTool]): A dictionary mapping tool\n                  names to their executable `BaseTool` objects.\n                - tool_call_parser_type (str): The identifier for the specific\n                  parser type (e.g., 'json_mode', 'tool_code') used to extract\n                  tool calls.\n                - sgl_tools (list[sglang.srt.openai_api.protocol.Tool]): Tool\n                  definitions optimized for SGLang's internal engine.\n                - function_call_parser (sglang.srt.function_call_parser.FunctionCallParser):\n                  The active parser instance responsible for extracting\n                  structured tool calls from model outputs.\n        \"\"\"\n        if config.multi_turn.tool_config_path is None:\n            return [], {}, None, [], None\n\n        tools_config_file = config.multi_turn.tool_config_path\n        tool_list = initialize_tools_from_config(tools_config_file)\n\n        logger.info(f\"Initialized tools from configuration: tool_list: {tool_list}\")\n        tool_schemas = [tool.get_openai_tool_schema().model_dump() for tool in tool_list]\n        tool_map = {tool.name: tool for tool in tool_list}\n        tool_call_parser_type = get_tool_call_parser_type(processing_class)\n        sgl_tools = [Tool.model_validate(tool_schema) for tool_schema in tool_schemas]\n        function_call_parser = FunctionCallParser(\n            sgl_tools,\n            tool_call_parser_type,\n        )\n\n        return (\n            tool_schemas,\n            tool_map,\n            
tool_call_parser_type,\n            sgl_tools,\n            function_call_parser,\n        )\n\n    def _initialize_interactions(self, config):\n        \"\"\"Initialize interactions from configuration.\n\n        Returns:\n            dict[str, BaseInteraction]: A dictionary mapping interaction names to interaction instances.\n        \"\"\"\n        if config.multi_turn.interaction_config_path is None:\n            return {}\n\n        interaction_config_file = config.multi_turn.interaction_config_path\n        interaction_map = initialize_interactions_from_config(interaction_config_file)\n\n        logger.info(f\"Initialize interactions from configuration: interaction_map: {list(interaction_map.keys())}\")\n        return interaction_map\n\n    @GPUMemoryLogger(role=\"sglang rollout\", logger=logger)\n    @torch.no_grad()\n    def generate_sequences(self, prompts: TensorDict, **kwargs) -> TensorDict:\n        \"\"\"Generate sequences for a batch of prompts.\n\n        Args:\n            batch (TensorDict): Input batch.\n\n        Returns:\n            TensorDict: Output batch.\n            - prompts: [bsz, prompt_length], prompt token ids from dataset.\n            - responses: [bsz, response_length], output token ids include response tokens\n              from LLM generation and observation tokens from tool_calls.\n            - response_mask: [bsz, response_length], 1 for LLM generated tokens, 0 for observation/padding tokens.\n            - input_ids: [bsz, prompt_length + response_length], whole sequence token ids, including prompt tokens\n              and response tokens.\n            - attention_mask: [bsz, prompt_length + response_length], 0 for padding tokens, 1 for other tokens.\n            - position_ids: [bsz, prompt_length + response_length], incremental position ids.\n\n            For multi-turn conversations:\n            responses:     |<- LLM generation ->|<- tool_calls ->|<- LLM generation ->|<- padding ->|\n            response_mask: | 1, 1, 1, ..., 1, 1 | 0, 0, .., 0, 0 | 1, 1, 1, ..., 1, 1 | 0, 0, ..., 0|\n        \"\"\"\n        if self.config.multi_turn.enable:\n            return self._req_level_generate_sequences(prompts, **kwargs)\n        return self._batch_level_generate_sequences(prompts, **kwargs)\n\n    @GPUMemoryLogger(role=\"sglang rollout\", logger=logger)\n    @torch.no_grad()\n    def _batch_level_generate_sequences(self, prompts: TensorDict, **kwargs) -> TensorDict:\n        \"\"\"Generates single-turn sequences for a batch of prompts.\n        For single-turn generation, all prompts are processed in one request.\n        `_batch_level_generate_sequences` involves:\n        1.  Extracting and pre-processing prompt token IDs from the input\n            `prompts`. This includes handling padding and preparing raw\n            token ID lists.\n        2.  Preparing inputs for the SGLang engine, including multi-modal\n            data if present.\n        3.  Invoking the SGLang engine (`self.inference_engine.async_generate`,\n            an async coroutine) with the batch of processed inputs and\n            specified sampling parameters on the master TP rank.\n        4.  Broadcasting the results from the master TP rank to all\n            other TP ranks.\n        5.  Post-processing the engine's output to format the generated\n            token IDs and (if applicable) log probabilities.\n        6.  Constructing the final sequences by concatenating original\n            prompts with the generated responses.\n        7.  
Updating attention masks and position IDs to reflect the full\n            concatenated sequences.\n        8.  If `self.config.free_cache_engine` is true, the SGLang engine's\n            KV cache is flushed after generation on the master TP rank.\n        Args:\n            prompts: A `TensorDict` object containing the batch of\n              input prompts, including tensor data (like `input_ids`,\n              `attention_mask`) and meta-information (like `eos_token_id`,\n              `do_sample`).\n            **kwargs: Additional keyword arguments that can override the\n              default sampling parameters (e.g., `temperature`, `top_p`,\n              `max_new_tokens`). These are temporarily applied using\n              `update_sampling_params`.\n        Returns:\n            TensorDict: A `TensorDict` object containing the batch of\n              generated sequences. This includes tensors for `prompts`\n              (original input IDs), `responses` (generated token IDs),\n              `input_ids` (concatenated prompt and response),\n              `attention_mask`, and `position_ids` for the full\n              sequences.\n        Note that in GRPO, if the prompts are validated, we repeat the prompts for rollout.n times in ray_trainer.\n        Thus we do not need to repeat the prompts here and set the sampling parameter n to 1.\n        \"\"\"     \n        # input ids: (bs, prompt_length), left-padded\n        idx = prompts[\"input_ids\"]\n        # attention_mask: (bs, seq_length), left-padded\n        attention_mask = prompts[\"attention_mask\"]\n        position_ids = prompts[\"position_ids\"]\n\n        # used to generate attention mask for the\n        # response based on EOS token position\n        eos_token_id = prompts[\"eos_token_id\"]\n\n        batch_size = idx.size(0)\n\n        # Extract non-tensor data\n        if \"raw_prompt_ids\" not in prompts:\n            prompts[\"raw_prompt_ids\"] = np.array(\n                [_pre_process_inputs(self.pad_token_id, idx[i]).tolist() for i in range(batch_size)],\n                dtype=object,\n            )\n\n        if \"multi_modal_data\" in prompts:\n            sglang_inputs = []\n            for raw_prompt_ids, multi_modal_data in zip(\n                prompts.pop(\"raw_prompt_ids\").data,\n                prompts.pop(\"multi_modal_data\").data,\n            ):\n                sglang_inputs.append(\n                    {\n                        \"prompt_token_ids\": raw_prompt_ids,\n                        \"multi_modal_data\": multi_modal_data,\n                        \"image_data\": (\n                            multi_modal_data.get(\"image\", None) if isinstance(multi_modal_data, dict) else None\n                        ),\n                    }\n                )\n        else:\n            sglang_inputs = [\n                {\"prompt_token_ids\": raw_prompt_ids} for raw_prompt_ids in prompts.pop(\"raw_prompt_ids\").data\n            ]\n        \n        # Ensure token IDs are lists or numpy arrays\n        for input_data in sglang_inputs:\n            if isinstance(input_data[\"prompt_token_ids\"], np.ndarray):\n                input_data[\"prompt_token_ids\"] = input_data[\"prompt_token_ids\"].tolist()\n            elif not isinstance(input_data[\"prompt_token_ids\"], list):\n                raise TypeError(\n                    f\"prompt_token_ids must be a list or numpy array, got {type(input_data['prompt_token_ids'])}\"\n                )\n        \n        # Extract token IDs and image data for SGLang 
Engine\n        idx_list = [input_data[\"prompt_token_ids\"] for input_data in sglang_inputs]\n        image_list = [input_data.get(\"image_data\", None) for input_data in sglang_inputs]\n\n        do_sample = prompts[\"do_sample\"] if \"do_sample\" in prompts else True\n        is_validate = prompts[\"validate\"] if \"validate\" in prompts else False\n\n        # Create request-level sampling parameters\n        request_sampling_params = self.sampling_params.copy()\n        if not do_sample:\n            request_sampling_params.update(\n                {\n                    \"n\": 1,\n                    \"presence_penalty\": 0.0,\n                    \"frequency_penalty\": 0.0,\n                    \"repetition_penalty\": 1.0,\n                    \"temperature\": 0,\n                    \"top_p\": 1,\n                    \"top_k\": -1,\n                    \"ignore_eos\": False,\n                    \"min_new_tokens\": 0,\n                    \"max_new_tokens\": self.config.response_length,\n                    \"skip_special_tokens\": True,\n                    \"spaces_between_special_tokens\": True,\n                }\n            )\n        elif is_validate:\n            request_sampling_params.update(\n                {\n                    \"top_k\": self.config.val_kwargs.top_k,\n                    \"top_p\": self.config.val_kwargs.top_p,\n                    \"temperature\": self.config.val_kwargs.temperature,\n                    \"n\": 1,  # if validate, already repeat in ray_trainer\n                }\n            )\n\n        # Update with any additional kwargs\n        request_sampling_params.update(kwargs)\n        if self._tp_rank == 0:\n            loop = asyncio.get_event_loop()\n            output = loop.run_until_complete(\n                self.inference_engine.async_generate(\n                    prompt=None,  # because we have already convert it to prompt token id\n                    sampling_params=request_sampling_params,\n                    return_logprob=True,\n                    input_ids=idx_list,\n                    image_data=image_list,\n                )\n            )\n        else:\n            output = None\n\n        # Most naive implementation, can extract tensor and send via gloo if too slow\n        dist.barrier()\n        [output] = broadcast_pyobj(\n            data=[output],\n            rank=self._rank,\n            dist_group=self._device_mesh_cpu[\"tp\"].get_group(),\n            src=self._device_mesh_cpu[\"tp\"].mesh[0].item(),\n            force_cpu_device=False,\n        )\n        out = _post_process_outputs(self.processing_class, output)\n\n        response = out[0].to(idx.device)\n\n        if response.shape[1] < self.config.response_length:\n            response = pad_sequence_to_length(response, self.config.response_length, self.pad_token_id)\n\n        seq = torch.cat([idx, response], dim=-1)\n\n        response_length = response.size(1)\n        delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)\n        delta_position_id = delta_position_id.unsqueeze(0).repeat(batch_size, 1)\n        if position_ids.dim() == 3:  # qwen2vl mrope\n            delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, 3, -1)\n\n        # TODO(sgm): fix position_ids on right_pad\n        # prompt: left pad + response: right pad\n        # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]\n        # position_ids:   [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]\n        response_position_ids = 
position_ids[..., -1:] + delta_position_id\n        position_ids = torch.cat([position_ids, response_position_ids], dim=-1)\n        response_attention_mask = get_response_mask(\n            response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype\n        )\n        attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)\n\n        # all the tp ranks should contain the same data here. data in all ranks are valid\n        prompts[\"prompts\"] = idx\n        prompts[\"responses\"] =  response\n        prompts[\"input_ids\"] = seq\n        prompts[\"attention_mask\"] = attention_mask\n        prompts[\"position_ids\"] = position_ids\n\n        # free cache engine\n        if self.inference_engine is not None and self._tp_rank == 0:\n            loop = asyncio.get_event_loop()\n            loop.run_until_complete(self.inference_engine.flush_cache())\n\n        return prompts\n\n    async def _async_rollout_a_request(\n        self,\n        req: AsyncRolloutRequest,\n        do_sample: bool = True,\n        is_validate: bool = False,\n        **kwargs,\n    ) -> AsyncRolloutRequest:\n        assert self._tp_rank == 0, \"only the master process can call this function\"\n        _req = deepcopy(req)\n        finish_reason_type = None\n        output = None\n\n        current_turns = 0\n        user_turns = 0\n        user_turn_rewards = []\n        # Create request-level sampling parameters\n        request_sampling_params = self.sampling_params.copy()\n        if not do_sample:\n            request_sampling_params.update(\n                {\n                    \"n\": 1,\n                    \"presence_penalty\": 0.0,\n                    \"frequency_penalty\": 0.0,\n                    \"repetition_penalty\": 1.0,\n                    \"temperature\": 0,\n                    \"top_p\": 1,\n                    \"top_k\": -1,\n                    \"ignore_eos\": False,\n                    \"min_new_tokens\": 0,\n                    \"max_new_tokens\": self.config.response_length,\n                    \"skip_special_tokens\": True,\n                    \"spaces_between_special_tokens\": True,\n                }\n            )\n        elif is_validate:\n            request_sampling_params.update(\n                {\n                    \"top_k\": self.config.val_kwargs.top_k,\n                    \"top_p\": self.config.val_kwargs.top_p,\n                    \"temperature\": self.config.val_kwargs.temperature,\n                    \"n\": 1,  # if validate, already repeat in ray_trainer\n                }\n            )\n\n        # Update with any additional kwargs\n        request_sampling_params.update(kwargs)\n\n        while current_turns < self.config.multi_turn.max_assistant_turns:\n            if _req.state == AsyncRolloutRequestStateEnum.PENDING:\n                await self._handle_pending_state(_req)\n                _req.state = AsyncRolloutRequestStateEnum.RUNNING\n            elif _req.state == AsyncRolloutRequestStateEnum.TOOL_CALLING:\n                if _req.messages[-1].tool_calls is not None:\n                    parsed_tool_calls = _req.messages[-1].tool_calls\n                    tool_call_results = await asyncio.gather(\n                        *[\n                            self._tool_map[tool_call.function.name].execute(\n                                _req.request_id,\n                                tool_call.function.arguments,\n                                **_req.tools_kwargs[tool_call.function.name].get(\"execute_kwargs\", 
{}),\n                            )\n                            for tool_call in parsed_tool_calls\n                        ]\n                    )\n                    _req.add_tool_response_messages(self.processing_class, [resp for resp, _, _ in tool_call_results])\n                    for tool_call, (resp, reward, metrics) in zip(parsed_tool_calls, tool_call_results):\n                        _req.update_metrics(metrics, tool_call.function.name)\n                    if len(_req.input_ids) >= self.config.max_model_len:\n                        finish_reason_type = FinishReasonTypeEnum.STOP\n                        break\n                    _req.state = AsyncRolloutRequestStateEnum.RUNNING\n                else:\n                    raise ValueError(f\"Unexpected tool calling last message state: {_req.messages[-1]}\")\n            elif _req.state == AsyncRolloutRequestStateEnum.RUNNING:\n                # Only continue the conversation if the prompt length is not greater than max_model_len - 1,\n                # since SGLang raises an error when max_new_tokens + 1 is greater than max_model_len (the extra\n                # token accounts for the EOS token).\n                if len(_req.get_generation_prompt_ids(self.processing_class)) + 1 >= self.config.max_model_len:\n                    finish_reason_type = FinishReasonTypeEnum.LENGTH\n                    break\n\n                # Video support is not implemented yet\n                image_data = (\n                    _req.multi_modal_data[\"image\"]\n                    if _req.multi_modal_data and \"image\" in _req.multi_modal_data\n                    else None\n                )\n                video_data = (\n                    _req.multi_modal_data[\"video\"]\n                    if _req.multi_modal_data and \"video\" in _req.multi_modal_data\n                    else None\n                )\n                if video_data:\n                    logger.warning(\n                        \"video support is not implemented yet, current length of video data is %d\", len(video_data)\n                    )\n\n                output = await self._handle_engine_call(_req, request_sampling_params, image_data=image_data)\n                content = output[\"text\"]\n                finish_reason_type = FinishReasonTypeEnum.from_str(output[\"meta_info\"][\"finish_reason\"][\"type\"])\n                current_turns += 1\n                if finish_reason_type == FinishReasonTypeEnum.LENGTH:\n                    _req.add_assistant_message(self.processing_class, content)\n                    break\n                else:\n                    if self._function_call_parser and self._function_call_parser.has_tool_call(content):\n                        finish_reason_type = FinishReasonTypeEnum.TOOL_CALL\n                        _req.state = AsyncRolloutRequestStateEnum.TOOL_CALLING\n                        try:\n                            normed_content, tool_calls = self._function_call_parser.parse_non_stream(content)\n                        except (JSONDecodeError, AttributeError):\n                            # fall back to treating the whole output as plain content if parsing fails\n                            normed_content = content\n                            tool_calls = []\n                        parsed_tool_calls = []\n                        for tool_call in tool_calls:\n                            function, has_decode_error = OpenAIFunctionCallSchema.from_openai_function_parsed_schema(\n                    
            OpenAIFunctionParsedSchema(\n                                    name=tool_call.name,\n                                    arguments=tool_call.parameters,\n                                )\n                            )\n                            # Drop the tool call if its arguments has decode error\n                            if has_decode_error:\n                                continue\n                            parsed_tool_calls.append(\n                                OpenAIFunctionToolCall(\n                                    id=str(tool_call.tool_index),\n                                    function=function,\n                                )\n                            )\n                        if len(parsed_tool_calls) > 0:\n                            _req.add_assistant_message(\n                                self.processing_class, normed_content, tool_calls=parsed_tool_calls\n                            )\n                        else:\n                            _req.add_assistant_message(self.processing_class, content)\n                            finish_reason_type = FinishReasonTypeEnum.STOP\n                            _req.state = AsyncRolloutRequestStateEnum.COMPLETED\n                            break\n                    else:\n                        _req.add_assistant_message(\n                            self.processing_class,\n                            content,\n                        )\n                        if (\n                            _req.interaction_kwargs\n                            and self.interaction_map\n                            and user_turns < self.config.multi_turn.max_user_turns\n                            and current_turns < self.config.multi_turn.max_assistant_turns\n                        ):\n                            _req.state = AsyncRolloutRequestStateEnum.INTERACTING\n                        else:\n                            break\n            elif _req.state == AsyncRolloutRequestStateEnum.INTERACTING:\n                user_turns += 1\n                messages = [{\"role\": x.role, \"content\": x.content} for x in _req.messages]\n\n                # Get interaction by name from interaction_kwargs\n                interaction_name = _req.interaction_kwargs.get(\n                    \"name\", \"gsm8k\"\n                )  # Default to gsm8k for backward compatibility\n                if interaction_name not in self.interaction_map:\n                    raise ValueError(\n                        f\"Interaction '{interaction_name}' not found in interaction_map. 
Available interactions: \"\n                        f\"{list(self.interaction_map.keys())}\"\n                    )\n\n                interaction = self.interaction_map[interaction_name]\n                should_terminate_sequence, content, reward, metrics = await interaction.generate_response(\n                    _req.request_id, messages, **_req.interaction_kwargs\n                )\n                user_turn_rewards.append(reward)\n                if should_terminate_sequence:\n                    finish_reason_type = FinishReasonTypeEnum.STOP\n                    _req.state = AsyncRolloutRequestStateEnum.COMPLETED\n                    break\n                else:\n                    _req.add_user_message(self.processing_class, content)\n                    if len(_req.input_ids) >= self.config.max_model_len:\n                        finish_reason_type = FinishReasonTypeEnum.STOP\n                        break\n                    else:\n                        _req.state = AsyncRolloutRequestStateEnum.RUNNING\n\n        if current_turns >= self.config.multi_turn.max_assistant_turns:\n            finish_reason_type = FinishReasonTypeEnum.STOP\n\n        # Calculate the reward for each tool\n        async def calc_reward_and_release_fn(name: str, tool: BaseTool):\n            reward = await tool.calc_reward(_req.request_id, **_req.tools_kwargs[name].get(\"calc_reward_kwargs\", {}))\n            await tool.release(_req.request_id, **_req.tools_kwargs[name].get(\"release_kwargs\", {}))\n            return name, reward\n\n        tool_reward_tasks = []\n        for name in _req.tools_kwargs.keys():\n            tool = self._tool_map[name]\n            tool_reward_tasks.append(calc_reward_and_release_fn(name, tool))\n        tool_reward_scores = await asyncio.gather(*tool_reward_tasks)\n        tool_reward_scores = dict(tool_reward_scores)\n        all_rewards = {**tool_reward_scores, **{\"user_turn_rewards\": user_turn_rewards}}\n        _req.finalize(self.processing_class, all_rewards, finish_reason_type)\n\n        return _req\n\n    async def _handle_engine_call(\n        self, _req: AsyncRolloutRequest, sampling_params: dict, image_data: Optional[list[Any]] = None\n    ) -> dict:\n        generation_prompt_ids = _req.get_generation_prompt_ids(self.processing_class)\n        return await self._handle_engine_generate(generation_prompt_ids, sampling_params, image_data)\n\n    async def _handle_engine_generate(\n        self, generation_prompt_ids: list[int], sampling_params: dict, image_data: Optional[list[Any]] = None\n    ) -> dict:\n        max_new_tokens = min(self.config.response_length, self.config.max_model_len - len(generation_prompt_ids) - 1)\n        kwargs = sampling_params.copy()\n        kwargs[\"max_new_tokens\"] = max_new_tokens\n        kwargs[\"n\"] = 1  # group size is supported in preprocess\n        output = await self.inference_engine.async_generate(\n            input_ids=generation_prompt_ids,\n            sampling_params=kwargs,\n            return_logprob=False,\n            image_data=image_data,\n        )\n        return output\n\n    async def _handle_pending_state(self, _req: AsyncRolloutRequest) -> AsyncRolloutRequest:\n        if _req.tool_schemas is not None:\n            tool_creation_coroutines = []\n            for tool_schema in _req.tool_schemas:\n                tool = self._tool_map[tool_schema.function.name]\n                create_kwargs = _req.tools_kwargs[tool.name].get(\"create_kwargs\", {})\n                
tool_creation_coroutines.append(tool.create(_req.request_id, **create_kwargs))\n            await asyncio.gather(*tool_creation_coroutines)\n        if _req.interaction_kwargs and self.interaction_map:\n            interaction_kwargs = _req.interaction_kwargs\n            # Get interaction by name from interaction_kwargs\n            interaction_name = interaction_kwargs.get(\"name\", \"gsm8k\")  # Default to gsm8k for backward compatibility\n            if interaction_name not in self.interaction_map:\n                raise ValueError(\n                    f\"Interaction '{interaction_name}' not found in interaction_map. Available interactions: \"\n                    f\"{list(self.interaction_map.keys())}\"\n                )\n\n            interaction = self.interaction_map[interaction_name]\n            await interaction.start_interaction(_req.request_id, **interaction_kwargs)\n\n    @GPUMemoryLogger(role=\"sglang rollout\", logger=logger)\n    @torch.no_grad()\n    def generate_sequences_with_tools(self, prompts: TensorDict, **kwargs) -> TensorDict:\n        logger.warning(\n            \"`generate_sequences_with_tools` is deprecated, please use `generate_sequences(...)`\",\n            DeprecationWarning,\n            stacklevel=2,\n        )\n        return self._req_level_generate_sequences(prompts, **kwargs)\n\n    @GPUMemoryLogger(role=\"sglang rollout\", logger=logger)\n    @torch.no_grad()\n    def _req_level_generate_sequences(self, prompts: TensorDict, **kwargs) -> TensorDict:\n        \"\"\"Generates multi-turn sequences for a batch of prompts.\n        For multi-turn generation, each prompt is processed separately via\n        `_req_level_generate_sequences` for better tool calling control.\n        Note that in multi-turn generation, we repeat the prompts for rollout.n times in ray_trainer.\n        Thus we do not need to repeat the prompts here and set the sampling parameter n to 1.\n        \"\"\"\n        # Async rollout with tools support\n        do_sample = prompts[\"do_sample\"] if \"do_sample\" in prompts else True\n        is_validate = prompts[\"validate\"] if \"validate\" in prompts else False\n        tgt_device = prompts[\"input_ids\"].device\n        if self._tp_rank == 0:\n            req_list = self._preprocess_prompt_to_async_rollout_requests(\n                prompts,\n            )\n            loop = asyncio.get_event_loop()\n            output_req_list = loop.run_until_complete(\n                asyncio.gather(\n                    *[self._async_rollout_a_request(req, do_sample, is_validate, **kwargs) for req in req_list],\n                )\n            )\n            sorted_output_req_list = sorted(output_req_list, key=lambda x: (x.batch_data_id, x.rollout_offset))\n        else:\n            sorted_output_req_list = None\n\n        dist.barrier()\n        [sorted_output_req_list] = broadcast_pyobj(\n            data=[sorted_output_req_list],\n            rank=self._rank,\n            dist_group=self._device_mesh_cpu[\"tp\"].get_group(),\n            src=self._device_mesh_cpu[\"tp\"].mesh[0].item(),\n            force_cpu_device=False,\n        )\n        # Construct the batch data\n        prompt_ids, response_ids = [], []\n        prompt_attention_mask, response_attention_mask = [], []\n        prompt_position_ids, response_position_ids = [], []\n        prompt_loss_mask, response_loss_mask = [], []\n        messages = []\n        reward_scores = []\n        multi_modal_inputs = []\n\n        for req in sorted_output_req_list:\n            assert 
req.state == AsyncRolloutRequestStateEnum.COMPLETED, f\"Request {req.request_id} is not completed\"\n            assert (\n                req.input_ids.shape[-1]\n                == req.attention_mask.shape[-1]\n                == req.position_ids.shape[-1]\n                == req.loss_mask.shape[-1]\n            ), f\"\"\"Request {req.request_id} has different length of \n                {req.input_ids.shape[-1]=}, {req.attention_mask.shape[-1]=}, \n                {req.position_ids.shape[-1]=}, {req.loss_mask.shape[-1]=}\"\"\"\n            error_message_lines = [\n                f\"\"\"Request {req.request_id} has input_ids length {req.input_ids.shape[-1]}\n                    greater than max_model_len {self.config.max_model_len}\"\"\",\n                f\"Decoded input_ids: {self.processing_class.decode(req.input_ids.squeeze(0))}\",\n                f\"Decoded prompt_ids: {self.processing_class.decode(req.prompt_ids.squeeze(0))}\",\n                f\"Decoded response_ids: {self.processing_class.decode(req.response_ids.squeeze(0))}\",\n                f\"Messages: {req.messages}\",\n                f\"Max model length: {req.max_model_len}\",\n            ]\n            error_message = \"\\n\".join(error_message_lines)\n            assert req.input_ids.shape[-1] <= self.config.max_model_len, error_message\n\n            prompt_ids.append(req.prompt_ids.to(tgt_device).squeeze(0))\n            response_ids.append(req.response_ids.to(tgt_device).squeeze(0))\n            if req.response_ids.shape[-1] > self.config.response_length:\n                logger.warning(\n                    f\"\"\"{req.request_id=} has response_ids length {req.response_ids.shape[-1]} \n                    greater than max_response_len {self.config.response_length},\\n{req=}\"\"\"\n                )\n            prompt_attention_mask.append(req.prompt_attention_mask.to(tgt_device).squeeze(0))\n            response_attention_mask.append(req.response_attention_mask.to(tgt_device).squeeze(0))\n            prompt_position_ids.append(req.prompt_position_ids.to(tgt_device).squeeze(0))\n            response_position_ids.append(req.response_position_ids.to(tgt_device).squeeze(0))\n            prompt_loss_mask.append(req.prompt_loss_mask.to(tgt_device).squeeze(0))\n            response_loss_mask.append(req.response_loss_mask.to(tgt_device).squeeze(0))\n            messages.append({\"messages\": req.messages})\n            reward_scores.append(req.reward_scores)\n            multi_modal_inputs.append(req.multi_modal_inputs)\n\n        prompt_ids = pad_sequence(\n            prompt_ids,\n            batch_first=True,\n            padding_value=self.pad_token_id,\n            padding_side=\"left\",\n        )\n        if prompt_ids.shape[-1] < self.config.prompt_length:\n            prompt_ids = pad_sequence_to_length(prompt_ids, self.config.prompt_length, self.pad_token_id, left_pad=True)\n        response_ids = pad_sequence(response_ids, batch_first=True, padding_value=self.pad_token_id)\n        if response_ids.shape[-1] < self.config.response_length:\n            response_ids = pad_sequence_to_length(response_ids, self.config.response_length, self.pad_token_id)\n        prompt_attention_mask = pad_sequence(\n            prompt_attention_mask,\n            batch_first=True,\n            padding_value=0,\n            padding_side=\"left\",\n        )\n        if prompt_attention_mask.shape[-1] < self.config.prompt_length:\n            prompt_attention_mask = pad_sequence_to_length(\n                prompt_attention_mask, 
self.config.prompt_length, 0, left_pad=True\n            )\n        response_attention_mask = pad_sequence(response_attention_mask, batch_first=True, padding_value=0)\n        if response_attention_mask.shape[-1] < self.config.response_length:\n            response_attention_mask = pad_sequence_to_length(response_attention_mask, self.config.response_length, 0)\n\n        # padding prompt_position_ids\n        if prompt_position_ids[0].dim() == 2:\n            # if prompt_position_ids is a 2D tensor\n            # e.g. from qwen2vl, prompt_position_ids.shape = (3, seq_len)\n            transposed_prompt_position_ids = [p.transpose(0, 1) for p in prompt_position_ids]\n            prompt_position_ids = pad_sequence(\n                transposed_prompt_position_ids, batch_first=True, padding_value=0, padding_side=\"left\"\n            )\n            prompt_position_ids = prompt_position_ids.transpose(1, 2)\n        else:\n            prompt_position_ids = pad_sequence(\n                prompt_position_ids, batch_first=True, padding_value=0, padding_side=\"left\"\n            )\n        if prompt_position_ids.shape[-1] < self.config.prompt_length:\n            prompt_position_ids = pad_sequence_to_length(\n                prompt_position_ids, self.config.prompt_length, 0, left_pad=True\n            )\n\n        # padding response_position_ids\n        if response_position_ids[0].dim() == 2:\n            # if response_position_ids is a 2D tensor\n            # e.g. from qwen2vl, response_position_ids.shape = (3, seq_len)\n            transposed_response_position_ids = [p.transpose(0, 1) for p in response_position_ids]\n            response_position_ids = pad_sequence(\n                transposed_response_position_ids, batch_first=True, padding_value=0, padding_side=\"left\"\n            )\n            response_position_ids = response_position_ids.transpose(1, 2)\n        else:\n            response_position_ids = pad_sequence(response_position_ids, batch_first=True, padding_value=0)\n        if response_position_ids.shape[-1] < self.config.response_length:\n            response_position_ids = pad_sequence_to_length(response_position_ids, self.config.response_length, 0)\n\n        prompt_loss_mask = pad_sequence(prompt_loss_mask, batch_first=True, padding_value=0, padding_side=\"left\")\n        if prompt_loss_mask.shape[1] < self.config.prompt_length:\n            prompt_loss_mask = pad_sequence_to_length(prompt_loss_mask, self.config.prompt_length, 0, left_pad=True)\n        response_loss_mask = pad_sequence(response_loss_mask, batch_first=True, padding_value=0)\n        if response_loss_mask.shape[1] < self.config.response_length:\n            response_loss_mask = pad_sequence_to_length(response_loss_mask, self.config.response_length, 0)\n\n        input_ids = torch.cat((prompt_ids, response_ids), dim=-1)\n        attention_mask = torch.cat((prompt_attention_mask, response_attention_mask), dim=-1)\n        position_ids = torch.cat((prompt_position_ids, response_position_ids), dim=-1)\n\n        # Construct the batch data\n        batch = TensorDict(\n            {\n                \"prompts\": prompt_ids,\n                \"responses\": response_ids,\n                \"response_mask\": response_loss_mask,\n                \"input_ids\": input_ids,  # here input_ids become the whole sentences\n                \"attention_mask\": attention_mask,\n                \"position_ids\": position_ids,\n            },\n            batch_size=len(sorted_output_req_list),\n        )\n\n        # free cache 
engine\n        if self.inference_engine is not None and self._tp_rank == 0:\n            loop = asyncio.get_event_loop()\n            loop.run_until_complete(self.inference_engine.flush_cache())\n        prompts.update(batch)\n        prompts[\"messages\"] = np.array(messages)\n        prompts[\"reward_scores\"] = np.array(reward_scores)\n        prompts[\"multi_modal_inputs\"] = multi_modal_inputs\n        return prompts\n\n    def _preprocess_prompt_to_async_rollout_requests(self, prompts: TensorDict, n: int = 1) -> list[AsyncRolloutRequest]:\n        assert \"raw_prompt\" in prompts, (\n            \"need data.return_raw_chat=True, due to no official way do parse_messages\"\n        )\n        logger.info(\n            \"n is deprecated for SGLang rollout since ray ppo trainer will repeat the prompts for rollout.n times\"\n        )\n        req_list = []\n        \n        multi_modal_data_list = prompts[\"multi_modal_data\"] if \"multi_modal_data\" in prompts else  [None] * len(prompts[\"raw_prompt\"])\n\n        for data_idx, (raw_prompt, multi_modal_data) in enumerate(\n            zip(prompts[\"raw_prompt\"], multi_modal_data_list)\n        ):\n            if self._tool_schemas:\n                _tools_kwargs = prompts[\"tools_kwargs\"][data_idx]\n                _tool_schemas = [self._tool_map[k].get_openai_tool_schema() for k in _tools_kwargs.keys()]\n                _input_ids = None\n                _attention_mask = None\n            else:\n                _input_ids = _pre_process_inputs(self.pad_token_id, prompts[\"input_ids\"][data_idx])\n                _attention_mask = _pre_process_inputs(0, prompts[\"attention_mask\"][data_idx])\n                _tools_kwargs = {}\n                _tool_schemas = None\n\n            if self.interaction_map:\n                _interaction_kwargs = prompts[\"interaction_kwargs\"][data_idx]\n            else:\n                _interaction_kwargs = {}\n\n            req = AsyncRolloutRequest(\n                batch_data_id=data_idx,\n                rollout_offset=0,\n                request_id=str(uuid4()),\n                state=AsyncRolloutRequestStateEnum.PENDING,\n                messages=raw_prompt.tolist(),\n                multi_modal_data=multi_modal_data,\n                tool_schemas=_tool_schemas,\n                tools_kwargs=_tools_kwargs,\n                interaction_kwargs=_interaction_kwargs,\n                input_ids=_input_ids,\n                response_ids=None,\n                attention_mask=_attention_mask,\n                response_attention_mask=None,\n                response_position_ids=None,\n                response_loss_mask=None,\n                reward_scores={},\n                max_prompt_len=self.config.prompt_length,\n                max_response_len=self.config.response_length,\n                max_model_len=min(self.config.max_model_len, self.config.prompt_length + self.config.response_length),\n                use_inference_chat_template=self.config.multi_turn.use_inference_chat_template,\n                tokenization_sanity_check_mode=self.config.multi_turn.tokenization_sanity_check_mode,\n                processing_class=self.processing_class,\n            )\n            error_message = f\"\"\"Request {req.request_id} has mismatched lengths: \n            input_ids={req.input_ids.shape[-1]}, \n            attention_mask={req.attention_mask.shape[-1]}, \n            position_ids={req.position_ids.shape[-1]}, \n            loss_mask={req.loss_mask.shape[-1]}\"\"\"\n            assert (\n              
  req.input_ids.shape[-1]\n                == req.attention_mask.shape[-1]\n                == req.position_ids.shape[-1]\n                == req.loss_mask.shape[-1]\n            ), error_message\n            req_list.append(req)\n\n        return req_list\n\n    async def chat_completion(self, json_request):\n        assert self._tp_rank == 0, \"only called in tp rank 0\"\n        _input_ids = None\n        _attention_mask = None\n        _position_ids = None\n        _tool_schemas = []\n        _tools_kwargs = {}\n\n        req = AsyncRolloutRequest(\n            request_id=str(uuid4()),\n            state=AsyncRolloutRequestStateEnum.PENDING,\n            messages=[Message.model_validate(msg) for msg in json_request[\"messages\"]],\n            tool_schemas=_tool_schemas,\n            tools_kwargs=_tools_kwargs,\n            input_ids=_input_ids,\n            prompt_ids=_input_ids,\n            response_ids=None,\n            attention_mask=_attention_mask,\n            prompt_attention_mask=_attention_mask,\n            response_attention_mask=None,\n            position_ids=_position_ids,\n            prompt_position_ids=_position_ids,\n            response_position_ids=None,\n            loss_mask=None,\n            prompt_loss_mask=None,\n            response_loss_mask=None,\n            reward_scores={},\n            max_prompt_len=self.config.prompt_length,\n            max_response_len=self.config.response_length,\n            max_model_len=min(self.config.max_model_len, self.config.prompt_length + self.config.response_length),\n            use_inference_chat_template=self.config.multi_turn.use_inference_chat_template,\n            tokenization_sanity_check_mode=self.config.multi_turn.tokenization_sanity_check_mode,\n            processing_class=self.processing_class,\n        )\n\n        # json_request already contains sampling_params\n        # Filter only valid SamplingParams arguments\n        valid_sampling_params = {}\n        temp_sampling_params = SamplingParams()  # Create temporary instance to check valid attributes\n        for k, v in json_request.items():\n            if k not in [\"messages\", \"model\", \"tools\"] and hasattr(temp_sampling_params, k):\n                valid_sampling_params[k] = v\n        output = await self._handle_engine_call(req, valid_sampling_params)\n        # it can be Dict or AsyncIterator[Dict]\n        if isinstance(output, dict):\n            outputs = [output]\n        else:\n            outputs = output\n\n        # build openai chat completion format\n        choices = []\n        id = None\n        for i, content in enumerate(outputs):\n            choices.append(\n                {\n                    \"index\": i,\n                    \"message\": {\n                        \"role\": \"assistant\",\n                        \"content\": content[\"text\"],\n                    },\n                    \"finish_reason\": content[\"meta_info\"][\"finish_reason\"][\"type\"],\n                }\n            )\n            id = content[\"meta_info\"][\"id\"]\n\n        return {\n            \"id\": \"chatcmpl-\" + id,\n            \"object\": \"chat.completion\",\n            \"created\": int(time.time()),\n            \"model\": json_request.get(\"model\", \"sglang_model\"),\n            \"choices\": choices,\n        }\n\n        # this function is left for uniform train-inference resharding\n\n    async def generate(\n        self, prompt_ids: torch.Tensor, sampling_params: dict[str, Any], request_id: str\n    ) -> torch.Tensor:\n     
   request_sampling_params = self.sampling_params.copy()\n        request_sampling_params.update(sampling_params)\n        output = await self._handle_engine_generate(prompt_ids, request_sampling_params)\n        return output[\"text\"] if self.config.mode == \"sync\" else output[\"output_ids\"]\n\n    async def wake_up(self):\n        if not self.is_sleep:\n            return\n        await self.sharding_manager.wake_up()  # pylint: disable=C2801\n        self.is_sleep = False\n\n    # this function is left for uniform train-inference resharding\n    async def sleep(self):\n        if self.is_sleep:\n            return\n        await self.sharding_manager.sleep()\n        self.is_sleep = True\n\n    # used for async mode\n    def _init_zeromq(self) -> str:\n        tensor_parallel_size = self.config.tensor_model_parallel_size\n\n        # single node: ipc, multiple nodes: tcp\n        local_world_size = int(os.environ[\"RAY_LOCAL_WORLD_SIZE\"])\n        socket_type = \"ipc\" if tensor_parallel_size <= local_world_size else \"tcp\"\n\n        # File lock to prevent multiple workers from listening on the same port\n        with FileLock(\"/tmp/siirl_vllm_zmq.lock\"):\n            if socket_type == \"ipc\":\n                pid = os.getpid()\n                address = f\"ipc:///tmp/siirl_vllm_zmq_{pid}.ipc\"\n            else:\n                ip, port = self._get_free_port()\n                address = f\"tcp://{ip}:{port}\"\n            context = zmq.Context()\n            self.socket = context.socket(zmq.REP)\n            self.socket.bind(address)\n\n        self.loop_thread = threading.Thread(target=self._loop_forever)\n        self.loop_thread.start()\n        return address\n\n    def _get_free_port(self):\n        ip = ray._private.services.get_node_ip_address()\n        with socket.socket() as sock:\n            sock.bind((\"\", 0))\n            port = sock.getsockname()[1]\n        return ip, port\n\n    def _loop_forever(self):\n        while True:\n            message = self.socket.recv()\n            method, args, kwargs = pickle.loads(message)\n            result = self.execute_method(method, *args, **kwargs)\n            self.socket.send(pickle.dumps(result))\n\n    def get_zeromq_address(self):\n        return self.address\n\n    def execute_method(self, method: Union[str, bytes], *args, **kwargs):\n        if method == \"generate\":\n            loop = ensure_event_loop()\n            return loop.run_until_complete(self.generate(*args, **kwargs))\n        elif method == \"sleep\":\n            loop = ensure_event_loop()\n            return loop.run_until_complete(self.sleep())\n        elif method == \"wake_up\":\n            loop = ensure_event_loop()\n            return loop.run_until_complete(self.wake_up())\n        else:\n            raise NotImplementedError(f\"{method} has not been implemented\")\n\n    def get_device_mesh(self):\n        return self._device_mesh_cpu\n\n\ndef ensure_event_loop():\n    try:\n        return asyncio.get_running_loop()\n    except RuntimeError:\n        loop = asyncio.new_event_loop()\n        asyncio.set_event_loop(loop)\n        return loop\n"
  },
  {
    "path": "siirl/engine/rollout/sglang_rollout/utils.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pickle\nfrom typing import Any, List, Optional\n\nimport numpy as np\nimport torch\nimport torch.distributed as dist\n\nfrom siirl.utils.extras.device import get_device_name\n\n\ndef broadcast_pyobj(\n    data: List[Any],\n    rank: int,\n    dist_group: Optional[torch.distributed.ProcessGroup] = None,\n    src: int = 0,\n    force_cpu_device: bool = False,\n):\n    \"\"\"from https://github.com/sgl-project/sglang/blob/844e2f227ab0cce6ef818a719170ce37b9eb1e1b/python/sglang/srt/utils.py#L905\n\n    Broadcast inputs from src rank to all other ranks with torch.dist backend.\n    The `rank` here refer to the source rank on global process group (regardless\n    of dist_group argument).\n    \"\"\"\n    device = torch.device(get_device_name() if not force_cpu_device else \"cpu\")\n\n    if rank == src:\n        if len(data) == 0:\n            tensor_size = torch.tensor([0], dtype=torch.long, device=device)\n            dist.broadcast(tensor_size, src=src, group=dist_group)\n        else:\n            serialized_data = pickle.dumps(data)\n            size = len(serialized_data)\n\n            tensor_data = torch.ByteTensor(np.frombuffer(serialized_data, dtype=np.uint8)).to(device)\n            tensor_size = torch.tensor([size], dtype=torch.long, device=device)\n\n            dist.broadcast(tensor_size, src=src, group=dist_group)\n            dist.broadcast(tensor_data, src=src, group=dist_group)\n        return data\n    else:\n        tensor_size = torch.tensor([0], dtype=torch.long, device=device)\n        dist.broadcast(tensor_size, src=src, group=dist_group)\n        size = tensor_size.item()\n\n        if size == 0:\n            return []\n\n        tensor_data = torch.empty(size, dtype=torch.uint8, device=device)\n        dist.broadcast(tensor_data, src=src, group=dist_group)\n\n        serialized_data = bytes(tensor_data.cpu().numpy())\n        data = pickle.loads(serialized_data)\n        return data\n"
  },
  {
    "path": "siirl/engine/rollout/vllm_rollout/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport os\nfrom importlib.metadata import PackageNotFoundError, version\n\nfrom packaging.version import Version\n\n\ndef get_version(pkg):\n    try:\n        return version(pkg)\n    except PackageNotFoundError:\n        return None\n\n\nvllm_package_name = \"vllm\"\nvllm_package_version = get_version(vllm_package_name)\nif vllm_package_version is None:\n    raise PackageNotFoundError(\"To use vllm rollout, please ensure the 'vllm' package is properly installed.\")\n\n###\n# package_version = get_version(package_name)\n# [SUPPORT AMD:]\n# Do not call any torch.cuda* API here, or ray actor creation import class will fail.\nif \"ROCM_PATH\" in os.environ:\n    import re\n\n    match = re.match(r\"(\\d+\\.\\d+\\.?\\d*)\", vllm_package_version)\n    if match:\n        vllm_package_version = match.group(1)\n    else:\n        raise ValueError(f\"Warning: Could not parse version format: {vllm_package_version}\")\n###\n\nvllm_mode = \"spmd\"\nfrom .vllm_rollout_spmd import vLLMAsyncRollout, vLLMRollout  # noqa: F401\n"
  },
  {
    "path": "siirl/engine/rollout/vllm_rollout/vllm_async_server.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport logging\nimport os\nimport pickle\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport ray\nimport zmq\nfrom omegaconf import DictConfig\nfrom starlette.requests import Request\nfrom starlette.responses import JSONResponse, StreamingResponse\nfrom vllm import SamplingParams\nfrom vllm.engine.arg_utils import AsyncEngineArgs\nfrom vllm.entrypoints.openai.protocol import ChatCompletionRequest, ChatCompletionResponse, ErrorResponse\nfrom vllm.inputs import TokensPrompt\nfrom vllm.outputs import RequestOutput\nfrom vllm.v1.engine.async_llm import AsyncLLM\nfrom vllm.v1.executor.abstract import Executor\nfrom vllm.worker.worker_base import WorkerWrapperBase\nfrom siirl.utils.debug import GPUMemoryLogger\nfrom siirl.utils.extras.fs import copy_to_local\nfrom siirl.engine.rollout.async_server import AsyncServerBase\nfrom tensordict import TensorDict\nfrom siirl.params.model_args import ActorRolloutRefArguments\nimport torch\nlogger = logging.getLogger(__file__)\n\n\nclass ExternalZeroMQDistributedExecutor(Executor):\n    \"\"\"An executor that engines are launched by external ray actors.\"\"\"\n\n    uses_ray: bool = False\n\n    def _init_executor(self) -> None:\n        addresses = os.environ[\"SIIRL_VLLM_ZMQ_ADDRESSES\"].split(\",\")\n        self.context = zmq.Context()\n        self.sockets = []\n        for address in addresses:\n            socket = self.context.socket(zmq.REQ)\n            socket.connect(address)\n            self.sockets.append(socket)\n\n        kwargs = dict(\n            vllm_config=self.vllm_config,\n            local_rank=None,\n            rank=None,\n            distributed_init_method=\"env://\",\n            is_driver_worker=True,\n        )\n        self.collective_rpc(\"init_worker\", args=([kwargs],))\n        self.collective_rpc(\"init_device\")\n        self.collective_rpc(\"load_model\")\n        \n\n    def collective_rpc(\n        self,\n        method: Union[str, Callable],\n        timeout: Optional[float] = None,\n        args: Tuple = (),\n        kwargs: Optional[Dict[str, Any]] = None,\n    ) -> List[Any]:\n        if isinstance(method, str):\n            sent_method = method\n        else:\n            sent_method = pickle.dumps(method)\n        del method\n\n        message = pickle.dumps((sent_method, args, kwargs or {}))\n        for socket in self.sockets:\n            socket.send(message, zmq.DONTWAIT)\n\n        outputs = []\n        for socket in self.sockets:\n            outputs.append(pickle.loads(socket.recv()))\n        return outputs\n\n    def check_health(self):\n        return\n\nclass AsyncvLLMServer(AsyncServerBase):\n    \"\"\"\n    AsyncvLLMServer is a wrapper for AsyncLLM, it uses ExternalRayDistributedExecutor to launch engines\n    in hybrid rollout workers, i.e AsyncActorRolloutRefWorker.\n\n    AsyncvLLMServer works as follows:\n    1. 
Start the FastAPI server first.\n    2. Initialize AsyncLLM with ExternalRayDistributedExecutor.\n    3. AsyncLLM spawns EngineCore in a subprocess.\n    4. EngineCore initializes ExternalRayDistributedExecutor.\n    5. ExternalRayDistributedExecutor looks up its corresponding actors by name.\n    6. ExternalRayDistributedExecutor initializes the executor: init_worker, init_device, load_model.\n\n    For the vLLM AsyncLLM design, see: https://github.com/vllm-project/vllm/pull/9826\n    \"\"\"\n\n    def __init__(self, config: ActorRolloutRefArguments, spmd_engine: Any, zmq_addresses: List):\n        \"\"\"\n        Args:\n            config: ActorRolloutRefArguments, the rollout configuration.\n            spmd_engine: Any, engine used by the SGLang backend; not needed for vLLM.\n            zmq_addresses: List, ZeroMQ addresses of the rollout workers.\n        \"\"\"\n        super().__init__()\n\n        self.config = config\n        self.spmd_engine = spmd_engine\n        self.zmq_addresses = zmq_addresses\n        self.engine: AsyncLLM = None\n\n    def init_engine(self):\n        \"\"\"Init vLLM AsyncLLM engine.\"\"\"\n        config = self.config\n        model_path = config.model.path\n        local_path = copy_to_local(model_path)\n        trust_remote_code = config.rollout.trust_remote_code\n        config = config.rollout\n\n        tensor_parallel_size = config.tensor_model_parallel_size\n        max_num_batched_tokens = config.max_num_batched_tokens\n        max_model_len = config.max_model_len if config.max_model_len else config.prompt_length + config.response_length\n        self.max_model_len = int(max_model_len)\n\n        # Override the default generation config from the Hugging Face model config;\n        # users can still override these by passing kwargs in each request.\n        kwargs = dict(\n            n=1,\n            logprobs=0,\n            repetition_penalty=1.0,\n            max_new_tokens=config.response_length,\n        )\n        # for k in config.keys():\n        #     if hasattr(SamplingParams(), str(k)):\n        #         kwargs[k] = config.get(k)\n        kwargs[\"temperature\"] = config.temperature\n        kwargs[\"top_k\"] = config.top_k\n        kwargs[\"top_p\"] = config.top_p\n\n        # only the ZeroMQ executor is supported\n        distributed_executor_backend = ExternalZeroMQDistributedExecutor\n\n        engine_args = AsyncEngineArgs(\n            model=local_path,\n            enable_sleep_mode=config.free_cache_engine,\n            override_generation_config=kwargs,\n            tensor_parallel_size=tensor_parallel_size,\n            distributed_executor_backend=distributed_executor_backend,\n            dtype=config.dtype,\n            enforce_eager=config.enforce_eager,\n            gpu_memory_utilization=config.gpu_memory_utilization,\n            disable_custom_all_reduce=True,\n            skip_tokenizer_init=False,\n            max_model_len=self.max_model_len,\n            load_format=\"auto\",\n            disable_log_stats=config.disable_log_stats,\n            max_num_batched_tokens=max_num_batched_tokens,\n            enable_chunked_prefill=config.enable_chunked_prefill,\n            enable_prefix_caching=True,\n            trust_remote_code=trust_remote_code,\n            seed=config.seed,\n        )\n        # init async llm engine\n        vllm_config = self._create_engine_config(engine_args, self.zmq_addresses)\n\n        self.engine = AsyncLLM.from_vllm_config(vllm_config)\n\n    def _create_engine_config(self, engine_args: AsyncEngineArgs, zmq_addresses: List):\n        # wg_prefix = os.environ['WG_PREFIX']\n       
        # local_world_size = os.environ['RAY_LOCAL_WORLD_SIZE']\n        # local_rank = os.environ['RAY_LOCAL_RANK']\n        vllm_config = engine_args.create_engine_config()\n        # Expose the worker ZeroMQ addresses to ExternalZeroMQDistributedExecutor via an environment variable.\n        if engine_args.distributed_executor_backend == ExternalZeroMQDistributedExecutor:\n            os.environ[\"SIIRL_VLLM_ZMQ_ADDRESSES\"] = \",\".join(zmq_addresses)\n\n        return vllm_config\n\n    @GPUMemoryLogger(role=\"vllm rollout spmd\", logger=logger)\n    def generate_sequences(self, prompts: TensorDict, **kwargs) -> TensorDict:\n        raise NotImplementedError(\"generate_sequences is not supported in async server mode; use the async generate() method instead.\")\n\n    async def chat_completion(self, raw_request: Request):\n        \"\"\"OpenAI-compatible HTTP endpoint.\n\n        API reference: https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html\n        \"\"\"\n        request_json = await raw_request.json()\n        request = ChatCompletionRequest(**request_json)\n        generator = await self.openai_serving_chat.create_chat_completion(request, raw_request)\n\n        if isinstance(generator, ErrorResponse):\n            return JSONResponse(content=generator.model_dump(), status_code=generator.code)\n        if request.stream:\n            return StreamingResponse(content=generator, media_type=\"text/event-stream\")\n        else:\n            assert isinstance(generator, ChatCompletionResponse)\n            return JSONResponse(content=generator.model_dump())\n\n    async def generate(self, prompt_ids: List[int], sampling_params: Dict[str, Any], request_id: str) -> List[int]:\n        max_tokens = self.max_model_len - len(prompt_ids)\n        sampling_params = SamplingParams(max_tokens=max_tokens, **sampling_params)\n        prompt = TokensPrompt(prompt_token_ids=prompt_ids)\n        generator = self.engine.generate(prompt=prompt, sampling_params=sampling_params, request_id=request_id)\n\n        # Get the final response from the streaming generator\n        final_res: Optional[RequestOutput] = None\n        async for output in generator:\n            final_res = output\n        assert final_res is not None\n\n        return final_res.outputs[0].token_ids\n\n    async def wake_up(self):\n        if self.config.rollout.free_cache_engine:\n            await self.engine.wake_up()\n\n    async def sleep(self):\n        # TODO: https://github.com/vllm-project/vllm/issues/17103\n        await self.engine.reset_prefix_cache()\n        if self.config.rollout.free_cache_engine:\n            await self.engine.sleep()\n"
  },
  {
    "path": "siirl/engine/rollout/vllm_rollout/vllm_rollout_spmd.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThe vllm_rollout that can be applied in different backend\nWhen working with FSDP:\n- Use DTensor weight loader (recommended) or HF weight loader\n- Utilize state_dict from the FSDP to synchronize the weights among tp ranks in vLLM\nWhen working with Megatron:\n- Use Megatron weight loader\n- During training, only the current pp stage holds the parameters\n- Before inference, broadcast the parameters of the current pp rank\n  to all other pp ranks (all pp ranks holds all the parameters)\n- Bind the parameters to the inference engine\n- Do inference in tp. pp is treated as additional dp\n- After inference, all the parameters that doesn't belong to this pp rank is freed.\n\"\"\"\n\n\n\nimport pickle\nimport socket\nimport threading\nimport ray\nimport zmq\nimport csv\nimport os\nimport time\n\nimport numpy as np\nimport torch\nimport torch.distributed\n\nfrom contextlib import contextmanager\nfrom copy import deepcopy\nfrom datetime import datetime\nfrom importlib.metadata import version\nfrom packaging import version as vs\nfrom typing import Any, Dict, List, Union\nfrom zoneinfo import ZoneInfo\n\n\nfrom filelock import FileLock\nfrom omegaconf import DictConfig, OmegaConf\nfrom types import MethodType\n\nfrom loguru import logger\nfrom tensordict import TensorDict\n# Handle different tensordict versions - NonTensorData location varies\ntry:\n    from tensordict import NonTensorData\nexcept ImportError:\n    from tensordict.tensorclass import NonTensorData\nfrom vllm import LLM, SamplingParams\nfrom vllm.distributed import parallel_state as vllm_ps\nfrom vllm.lora.request import LoRARequest\nfrom vllm.worker.worker_base import WorkerWrapperBase\nfrom vllm.model_executor.sampling_metadata import SamplingMetadata\n\nfrom siirl.utils.debug import GPUMemoryLogger\nfrom siirl.utils.model_utils.torch_functional import get_response_mask, pad_2d_list_to_length\nfrom siirl.params import RolloutArguments\nfrom siirl.engine.rollout.base import BaseRollout\nfrom siirl.utils.extras.device import is_cuda_available, device_synchronize\nfrom siirl.utils.extras.device import get_device_id\n# TODO\n# 1. support pp in vllm\n# 2. passing tokenizer is not necessary? no encoding/decoding is happending here\n# 3. simplify init logics\n\n\n# NOTE(sgm): add for siirl. 
 We can optimize it by making the dataloader yield List[int] without padding.\ndef _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> List[int]:\n    # remove the left padding in the prompt token_id\n    # pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id\n    # is not None else self.llm_engine.tokenizer.eos_token_id\n    non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0]\n    token_ids = prompt_token_ids[non_pad_index:].tolist()\n    return token_ids\n\n\ndef _repeat_interleave(value: Union[torch.Tensor, np.ndarray], repeats: int) -> Union[torch.Tensor, List[Any]]:\n    if isinstance(value, torch.Tensor):\n        return value.repeat_interleave(repeats, dim=0)\n    else:\n        return np.repeat(value, repeats, axis=0)\n\n\nclass vLLMRollout(BaseRollout):\n    def __init__(self, model_path: str, config: RolloutArguments, tokenizer, model_hf_config, **kwargs):\n        \"\"\"A vLLM rollout. It requires that the module be supported by vLLM.\n\n        Args:\n            module: module here follows huggingface APIs\n            config: RolloutArguments\n            tokenizer: the task/model tokenizer\n            model_hf_config: the Hugging Face config used to initialize the generating model in vLLM\n            **kwargs: train_tp, for the Megatron backend to initialize the hybrid engine (zero redundancy) process group\n        \"\"\"\n        super().__init__()\n        self.config = config\n\n        # micro_batch_size for mini-batch inference\n        self.micro_batch_size = config.micro_batch_size if config.micro_batch_size else 0\n        if self.micro_batch_size > 0:\n            logger.info(f\"Mini-batch inference is enabled with micro_batch_size: {self.micro_batch_size}\")\n\n        assert not (not config.enforce_eager and config.free_cache_engine), \"CUDA graph must be disabled (enforce_eager = True) when free_cache_engine is enabled\"\n\n        tensor_parallel_size = self.config.tensor_model_parallel_size\n        assert tensor_parallel_size <= torch.distributed.get_world_size(), \"tensor parallel size should be less than or equal to the world size\"\n        max_num_batched_tokens = self.config.max_num_batched_tokens\n\n        if kwargs.get(\"train_tp\") is not None:\n            # deployed with megatron\n            os.environ[\"CUDA_TIMER_STREAM_KAFKA_ENABLE\"] = \"0\"\n            os.environ[\"MEGATRON_IMPORT_TIMERS\"] = \"0\"\n            vllm_ps.initialize_model_parallel(tensor_model_parallel_size=tensor_parallel_size)\n\n        rope_scaling_config = getattr(model_hf_config, \"rope_scaling\", None)\n        if not rope_scaling_config:\n            max_position_embeddings = None\n            if hasattr(model_hf_config, \"max_position_embeddings\"):\n                max_position_embeddings = model_hf_config.max_position_embeddings\n            elif hasattr(model_hf_config, \"llm_config\") and hasattr(model_hf_config.llm_config, \"max_position_embeddings\"):\n                max_position_embeddings = model_hf_config.llm_config.max_position_embeddings\n            elif hasattr(model_hf_config, \"text_config\") and hasattr(model_hf_config.text_config, \"max_position_embeddings\"):\n                max_position_embeddings = model_hf_config.text_config.max_position_embeddings\n            if max_position_embeddings is None:\n                raise ValueError(\"max_position_embeddings not found in model_hf_config\")\n\n            assert max_position_embeddings >= config.prompt_length + config.response_length,
\"model context length should be greater than total sequence length\"\n\n        max_model_len = int(config.max_model_len or config.prompt_length + config.response_length)\n\n        if max_num_batched_tokens < max_model_len and self.config.enable_chunked_prefill:\n            raise ValueError(\n                \"Enable chunked prefill, max_num_batched_tokens is smaller than max_model_len, \\\n                             please increase max_num_batched_tokens or disable chunked prefill\"\n            )\n\n        trust_remote_code = kwargs.get(\"trust_remote_code\", False)\n        load_format = \"dummy\" if config.load_format.startswith(\"dummy\") else config.load_format\n\n        lora_kwargs = kwargs.pop(\"lora_kwargs\", {})\n        self.lora_kwargs = lora_kwargs\n        # copy it to avoid secretly modifying the engine config\n        engine_kwargs = deepcopy(config.engine_kwargs.vllm)\n        # For each vLLM engine parameter,\n        # - `None` means not setting it, so we pop it, and leave it to vLLM default value\n        #    (which can vary across different vLLM versions);\n        # - Otherwise it's the desired value we want to explicitly set.\n        engine_kwargs = {key: val for key, val in engine_kwargs.items() if val is not None}\n        if config.limit_images:  # support for multi-image data\n            engine_kwargs[\"limit_mm_per_prompt\"] = {\"image\": config.limit_images}\n\n        vllm_version = version(\"vllm\")\n        if vs.parse(vllm_version) >= vs.parse(\"0.9.0\") and is_cuda_available:\n            logger.info(f\"Add environment variable for vLLM version {vllm_version}\")\n            # This environment variable is mandatory due to an issue where PyTorch 2.7.0\n            #   causes flashinfer to fail with the error `FlashInfer requires sm75+`.\n            # This can be removed once a later version of PyTorch or flashinfer fixes the issue.\n            #   Reference: https://github.com/flashinfer-ai/flashinfer/issues/1101\n            cap = torch.cuda.get_device_capability(torch.cuda.current_device())\n            os.environ[\"VLLM_DISABLE_COMPILE_CACHE\"] = \"1\"\n            os.environ[\"TORCH_CUDA_ARCH_LIST\"] = f\"{cap[0]}.{cap[1]}+PTX\"\n\n        self.inference_engine = LLM(\n            model=model_path,\n            enable_sleep_mode=True,\n            tensor_parallel_size=tensor_parallel_size,\n            distributed_executor_backend=\"external_launcher\",\n            dtype=config.dtype,\n            enforce_eager=config.enforce_eager,\n            gpu_memory_utilization=config.gpu_memory_utilization,\n            disable_custom_all_reduce=True,\n            skip_tokenizer_init=False,\n            max_model_len=max_model_len,\n            load_format=load_format,\n            disable_log_stats=config.disable_log_stats,\n            max_num_batched_tokens=max_num_batched_tokens,\n            enable_chunked_prefill=config.enable_chunked_prefill,\n            enable_prefix_caching=True,\n            trust_remote_code=trust_remote_code,\n            seed=config.seed,  # Use None for random seed to avoid identical outputs for repeated inputs\n            **lora_kwargs,\n            **engine_kwargs,\n        )\n\n\n        # Offload vllm model to reduce peak memory usage\n        self.inference_engine.sleep(level=1)\n\n        kwargs = dict(\n            n=1,\n            logprobs=0,  # can be set to 0 and let actor to recompute\n            max_tokens=config.response_length,\n        )\n\n        # # we may detokenize the result all together 
later\n        kwargs[\"detokenize\"] = False\n\n        # supporting adding any sampling params from the config file\n        dictConfig = config.to_dict()\n        for k in dictConfig.keys():\n            if hasattr(SamplingParams(), str(k)) and k != \"seed\":\n                kwargs[k] = dictConfig.get(k)\n\n        kwargs[\"n\"] = 1  # already repeat in ray_trainer\n\n        logger.info(f\"kwargs: {kwargs}\")\n        self.sampling_params = SamplingParams(**kwargs)\n\n        if \"internvl\" in model_hf_config.model_type:\n            stop_tokens = [\"<|endoftext|>\", \"<|im_start|>\", \"<|im_end|>\", \"<|end|>\"]\n            stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]\n            self.sampling_params.stop_token_ids = stop_token_ids\n            if \"internlm2\" in model_hf_config.llm_config.model_type:\n                logger.info(\"Set vllm tokenizer for internlm2\")\n                self.inference_engine.set_tokenizer(tokenizer)\n\n        self.pad_token_id = tokenizer.pad_token_id\n\n        self.enbale_perf = os.environ.get(\"SIIRL_ENABLE_PERF\", \"0\") == \"1\"\n        if self.enbale_perf:\n            self.perf_step = 1\n            world_size = torch.distributed.get_world_size()\n            model_name = os.path.basename(os.path.normpath(model_path))\n            ts = datetime.now(tz=ZoneInfo(\"Asia/Shanghai\")).strftime(\"%Y-%m-%d-%H-%M-%S\")\n            self.perf_log_path = os.path.join(\"performance_logs\", model_name, ts)\n            os.makedirs(self.perf_log_path, exist_ok=True)\n            self.perf_log_file = os.path.join(self.perf_log_path, f\"rollout_world_size_{world_size}.csv\")\n            with open(self.perf_log_file, \"w\", newline=\"\") as f:\n                writer = csv.writer(f)\n                writer.writerow([\"step\", \"rank\", \"timestamp\", \"batch_size\", \"inference_latency_s\", \"min_prompt_len\", \"max_prompt_len\", \"avg_prompt_len\", \"min_gen_len\", \"max_gen_len\", \"avg_gen_len\"])\n\n    @contextmanager\n    def update_sampling_params(self, **kwargs):\n        # update sampling params\n        old_sampling_params_args = {}\n        if kwargs:\n            for key, value in kwargs.items():\n                if hasattr(self.sampling_params, key):\n                    old_value = getattr(self.sampling_params, key)\n                    old_sampling_params_args[key] = old_value\n                    setattr(self.sampling_params, key, value)\n        yield\n        # roll back to previous sampling params\n        # if len(old_sampling_params_args):\n        for key, value in old_sampling_params_args.items():\n            setattr(self.sampling_params, key, value)\n\n    @GPUMemoryLogger(role=\"vllm rollout spmd\", logger=logger)\n    @torch.no_grad()\n    def generate_sequences(self, prompts: TensorDict, **kwargs) -> TensorDict:\n        idx = prompts[\"input_ids\"]\n        # left-padded attention_mask \n        attention_mask = prompts[\"attention_mask\"]\n        position_ids = prompts[\"position_ids\"]\n        raw_prompt_ids = prompts.pop(\"raw_prompt_ids\").data\n        \n        # used to construct attention_mask\n        eos_token_id = prompts[\"eos_token_id\"]\n        batch_size = idx.size(0)\n        if batch_size != len(raw_prompt_ids):\n            raise RuntimeError(\"vllm sharding manager is not work properly.\")\n        if \"multi_modal_data\" in prompts:\n            vllm_inputs = []\n            for raw_prompt_id, multi_modal_data in zip(raw_prompt_ids, prompts.pop(\"multi_modal_data\").data):\n 
               vllm_inputs.append({\"prompt_token_ids\": raw_prompt_id, \"multi_modal_data\": multi_modal_data})\n        else:\n            vllm_inputs = [{\"prompt_token_ids\": raw_prompt_id} for raw_prompt_id in raw_prompt_ids]\n\n        # ensure the type of `prompt_token_ids` passed to vllm is list[int]\n        # https://github.com/volcengine/verl/pull/772\n        for i, input_data in enumerate(vllm_inputs):\n            if isinstance(input_data[\"prompt_token_ids\"], np.ndarray):\n                input_data[\"prompt_token_ids\"] = input_data[\"prompt_token_ids\"].tolist()\n            elif not isinstance(input_data[\"prompt_token_ids\"], list):\n                raise TypeError(f\"prompt_token_ids must be a list or numpy array, got {type(input_data['prompt_token_ids'])}\")\n\n\n        prompt_lengths = []\n        if self.enbale_perf:\n            # Record prompt length\n            prompt_lengths = [len(item[\"prompt_token_ids\"]) for item in vllm_inputs]\n            device_synchronize()\n            start_time = time.time()\n\n        \n        do_sample = prompts[\"do_sample\"] if \"do_sample\" in prompts else True\n        is_validate = prompts[\"validate\"] if \"validate\" in prompts else False\n\n        if not do_sample:\n            kwargs = {\n                \"best_of\": 1,\n                \"top_p\": 1.0,\n                \"top_k\": -1,\n                \"min_p\": 0.0,\n                \"temperature\": 0,\n                \"n\": 1,  # if greedy, only 1 response\n            }\n        elif is_validate:\n            # TODO: try **\n            kwargs = {\n                \"top_k\": self.config.val_kwargs.top_k,\n                \"top_p\": self.config.val_kwargs.top_p,\n                \"temperature\": self.config.val_kwargs.temperature,\n                \"n\": 1,  # if validate, already repeat in ray_trainer\n            }\n\n        lora_requests = None\n        if self.lora_kwargs:\n            lora_int_ids = list(self.inference_engine.llm_engine.list_loras())\n            if len(lora_int_ids) > 0:\n                lora_int_id = lora_int_ids[0]\n                lora_requests = [LoRARequest(lora_name=f\"{lora_int_id}\", lora_int_id=lora_int_id, lora_path=\"/simon-stub-path\")] * batch_size\n\n        # users can customize different sampling_params at different run\n        with self.update_sampling_params(**kwargs):\n            logger.info(f\"vllm generate sampling params: {self.sampling_params}\")\n            # if micro_batch_size is configured, split the batch into smaller chunks\n            # and generate sequences for each chunk sequentially.\n            if self.micro_batch_size > 0:\n                outputs = []\n                for i in range(0, len(vllm_inputs), self.micro_batch_size):\n                    micro_batch = vllm_inputs[i : i + self.micro_batch_size]\n                    if not micro_batch:\n                        continue\n\n                    micro_outputs = self.inference_engine.generate(\n                        prompts=micro_batch,\n                        sampling_params=self.sampling_params,\n                        use_tqdm=False,\n                    )\n                    outputs.extend(micro_outputs)\n            else:\n                # full-batch inference\n                outputs = self.inference_engine.generate(\n                    prompts=vllm_inputs,  # because we have already convert it to prompt token id\n                    sampling_params=self.sampling_params,\n                    lora_request=lora_requests,\n                
    use_tqdm=False,\n                )\n\n            if self.enbale_perf:\n                rank = torch.distributed.get_rank()\n                device_synchronize()\n                inference_latency = time.time() - start_time\n                # Record the length of generated tokens\n                generated_lengths = [len(o.outputs[0].token_ids) for o in outputs]\n                # Compute statistics\n                min_prompt_len = np.min(prompt_lengths) if prompt_lengths else 0\n                max_prompt_len = np.max(prompt_lengths) if prompt_lengths else 0\n                avg_prompt_len = np.mean(prompt_lengths) if prompt_lengths else 0\n\n                min_gen_len = np.min(generated_lengths) if generated_lengths else 0\n                max_gen_len = np.max(generated_lengths) if generated_lengths else 0\n                avg_gen_len = np.mean(generated_lengths) if generated_lengths else 0\n                # Write to CSV file\n                with open(self.perf_log_file, \"a\", newline=\"\") as f:\n                    writer = csv.writer(f)\n                    writer.writerow([self.perf_step, rank, time.strftime(\"%Y-%m-%d %H:%M:%S\"), batch_size, inference_latency, min_prompt_len, max_prompt_len, avg_prompt_len, min_gen_len, max_gen_len, avg_gen_len])\n                self.perf_step += 1\n                logger.info(f\"vllm rollout perf log saved to {self.perf_log_file}\")\n\n            # TODO(sgm): disable logprob when recompute_log_prob is enable\n            # if n = 1: (bs, response_length) ; if n > 1: (bs * n, response_length)\n            total_input_tokens = 0\n            total_output_tokens = 0\n            for input_data in vllm_inputs:\n                total_input_tokens += len(input_data[\"prompt_token_ids\"])\n\n            response = []\n            for output in outputs:\n                for sample_id in range(len(output.outputs)):\n                    response_ids = output.outputs[sample_id].token_ids\n                    response.append(response_ids)\n                    curr_log_prob = []\n                    for i, logprob in enumerate(output.outputs[sample_id].logprobs):\n                        curr_log_prob.append(logprob[response_ids[i]].logprob)\n                total_output_tokens += len(output.outputs[0].token_ids)\n            prompts[\"total_input_tokens\"] = NonTensorData(total_input_tokens, batch_size=None)\n            prompts[\"total_output_tokens\"] = NonTensorData(total_output_tokens, batch_size=None)\n            response = pad_2d_list_to_length(response, self.pad_token_id, max_length=self.config.response_length).to(idx.device)\n\n            seq = torch.cat([idx, response], dim=-1)\n\n        response_length = response.size(1)\n        delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)\n        delta_position_id = delta_position_id.unsqueeze(0).expand(batch_size, -1)\n        if position_ids.dim() == 3:  # qwen2vl mrope\n            delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, 3, -1)\n\n        # TODO(sgm): fix position_ids on right_pad\n        # prompt: left pad + response: right pad\n        # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]\n        # position_ids:   [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]\n        response_position_ids = position_ids[..., -1:] + delta_position_id\n        position_ids = torch.cat([position_ids, response_position_ids], dim=-1)\n        response_attention_mask = get_response_mask(response_id=response, eos_token=eos_token_id, 
dtype=attention_mask.dtype)\n        attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)\n        # left-padded attention_mask \n        # all the tp ranks should contain the same data here. data in all ranks are valid\n        prompts[\"prompts\"] = idx\n        prompts[\"attention_mask\"] = attention_mask\n        prompts[\"position_ids\"] = position_ids\n        prompts[\"responses\"] = response\n        prompts[\"input_ids\"] = seq  # here input_ids become the whole sentences\n        return prompts\n\n\n\n\n# https://github.com/vllm-project/vllm/issues/13175\ndef _monkey_patch_compute_logits(model, vocab_size: int):\n    original_compute_logits = model.compute_logits\n\n    def compute_logits(\n        self,\n        hidden_states: torch.Tensor,\n        sampling_metadata: SamplingMetadata,\n    ) -> torch.Tensor:\n        logits = original_compute_logits(hidden_states, sampling_metadata)\n        logits[..., vocab_size:] = float(\"-inf\")\n        return logits\n\n    model.compute_logits = MethodType(compute_logits, model)\n\nclass vLLMAsyncRollout:\n    \"\"\"vLLMAsyncRollout is a thin wrapper of WorkerWrapperBase,\n    which is engine in single worker process.\n    \"\"\"\n    def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_config, **kwargs):\n        self.tokenizer = tokenizer\n\n        # Engine is deferred to be initialized in init_worker\n        self.config = config\n        self.inference_engine: WorkerWrapperBase = None\n        self.sharding_manager = None\n        self.is_sleep = False\n        self.address = self._init_zeromq()\n\n    def _init_zeromq(self) -> str:\n        tensor_parallel_size = self.config.tensor_model_parallel_size\n\n        # single node: ipc, multi nodes: tcp\n        local_world_size = int(os.environ[\"RAY_LOCAL_WORLD_SIZE\"])\n        socket_type = \"ipc\" if tensor_parallel_size <= local_world_size else \"tcp\"\n\n        # File lock to prevent multiple workers listen to same port\n        with FileLock(\"/tmp/siirl_vllm_zmq.lock\"):\n            if socket_type == \"ipc\":\n                pid = os.getpid()\n                address = f\"ipc:///tmp/siirl_vllm_zmq_{pid}.ipc\"\n            else:\n                ip, port = self._get_free_port()\n                address = f\"tcp://{ip}:{port}\"\n            context = zmq.Context()\n            self.socket = context.socket(zmq.REP)\n            self.socket.bind(address)\n\n        self.loop_thread = threading.Thread(target=self._loop_forever)\n        self.loop_thread.start()\n        return address\n\n    def _get_free_port(self):\n        ip = ray._private.services.get_node_ip_address()\n        with socket.socket() as sock:\n            sock.bind((\"\", 0))\n            port = sock.getsockname()[1]\n        return ip, port\n\n    def _loop_forever(self):\n        while True:\n            message = self.socket.recv()\n            method, args, kwargs = pickle.loads(message)\n            result = self.execute_method(method, *args, **kwargs)\n            self.socket.send(pickle.dumps(result))\n\n    def get_zeromq_address(self):\n        return self.address\n\n    def init_worker(self, all_kwargs: List[Dict[str, Any]]):\n        \"\"\"Initialize worker engine.\"\"\"\n        all_kwargs[0][\"rank\"] = int(os.environ[\"RANK\"])\n        all_kwargs[0][\"local_rank\"] = 0\n\n        self.vllm_config = all_kwargs[0][\"vllm_config\"]\n        self.inference_engine = WorkerWrapperBase(vllm_config=self.vllm_config)\n        
self.inference_engine.init_worker(all_kwargs)\n\n\n    def load_model(self, *args, **kwargs):\n        self.inference_engine.load_model(*args, **kwargs)\n\n        # inference engine is initialized now, update sharding manager\n        self.sharding_manager.inference_engine = self.inference_engine\n        self.sharding_manager.model_runner = self.inference_engine.worker.model_runner\n\n        _monkey_patch_compute_logits(self.inference_engine.worker.model_runner.model, len(self.tokenizer))\n    def sleep(self, *args, **kwargs):\n        \"\"\"Offload model weights and discard kv cache.\"\"\"\n        if self.is_sleep:\n            return\n        self.sharding_manager.__exit__(None, None, None)\n        self.is_sleep = True\n\n    def wake_up(self, *args, **kwargs):\n        \"\"\"Load model weights and build kv cache.\"\"\"\n        if not self.is_sleep:\n            return\n        self.sharding_manager.__enter__()  # pylint: disable=C2801\n        self.is_sleep = False\n\n    def execute_method(self, method: Union[str, bytes], *args, **kwargs):\n        if method == \"init_worker\":\n            return self.init_worker(*args, **kwargs)\n        elif method == \"load_model\":\n            return self.load_model(*args, **kwargs)\n        elif method == \"sleep\":\n            return self.sleep(*args, **kwargs)\n        elif method == \"wake_up\":\n            return self.wake_up(*args, **kwargs)\n        else:\n            return self.inference_engine.execute_method(method, *args, **kwargs)\n"
  },
  {
    "path": "siirl/engine/sharding_manager/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .base import BaseShardingManager\nfrom .fsdp_hf import FSDPHFShardingManager\n# from .fsdp_vllm import MultiAgentFSDPVLLMShardingManager\n\n__all__ = [\n    \"BaseShardingManager\",\n    \"FSDPHFShardingManager\",\n]\n"
  },
  {
    "path": "siirl/engine/sharding_manager/base.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nSharding manager to implement HybridEngine\n\"\"\"\n\nfrom tensordict import TensorDict\n\n\nclass BaseShardingManager:\n    def __enter__(self):\n        pass\n\n    def __exit__(self, exc_type, exc_value, traceback):\n        pass\n\n    def preprocess_data(self, data: TensorDict) -> TensorDict:\n        return data\n\n    def postprocess_data(self, data: TensorDict) -> TensorDict:\n        return data\n"
  },
  {
    "path": "siirl/engine/sharding_manager/fsdp_hf.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nShardingManager for FSDP + HF Rollout (including EmbodiedHFRollout).\nManages model loading/offloading between training (actor) and inference (rollout).\n\"\"\"\n\nfrom loguru import logger\nfrom torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP\n\nfrom siirl.utils.extras.device import get_torch_device\nfrom siirl.utils.model_utils.fsdp_utils import load_fsdp_model_to_gpu, offload_fsdp_model_to_cpu\nfrom siirl.engine.sharding_manager.base import BaseShardingManager\n\n\nclass FSDPHFShardingManager(BaseShardingManager):\n    \"\"\"\n    ShardingManager for FSDP + HuggingFace Rollout.\n    \n    This manager handles model offloading for HF-based rollout (including EmbodiedHFRollout).\n    - In __enter__: Load actor model (and embedding model if needed) to GPU before rollout\n    - In __exit__: Offload actor model (and embedding model) to CPU after rollout\n    \n    This follows the same pattern as MultiAgentFSDPVLLMShardingManager and \n    MultiAgentFSDPSGLangShardingManager for consistency.\n    \"\"\"\n    \n    def __init__(\n        self, \n        module: FSDP, \n        rollout, \n        offload_param: bool = False,\n        offload_embedding: bool = False\n    ):\n        \"\"\"\n        Initialize FSDP HF Sharding Manager.\n        \n        Args:\n            module: The FSDP-wrapped actor model (actor_module_fsdp)\n            rollout: The rollout object (HFRollout or EmbodiedHFRollout)\n            offload_param: Whether to offload actor model parameters\n            offload_embedding: Whether to offload embedding model (for EmbodiedHFRollout)\n        \"\"\"\n        self.module = module\n        self.rollout = rollout\n        self.offload_param = offload_param\n        self.offload_embedding = offload_embedding\n        \n        # Track state\n        self.is_asleep = False  # Model starts on GPU after initialization\n        \n        logger.info(\n            f\"FSDPHFShardingManager initialized: \"\n            f\"offload_param={offload_param}, offload_embedding={offload_embedding}\"\n        )\n    \n    def __enter__(self):\n        \"\"\"\n        Called before rollout generation.\n        Load models to GPU if they were offloaded.\n        \"\"\"\n        if not self.is_asleep:\n            # Models already on GPU (first time or previous rollout didn't offload)\n            return\n        \n        # 1. Load actor model to GPU\n        if self.offload_param:\n            load_fsdp_model_to_gpu(self.module)\n        \n        # 2. 
Load embedding model to GPU (for EmbodiedHFRollout)\n        if self.offload_embedding:\n            self.rollout.embedding_model.load_to_device()\n        \n        self.is_asleep = False\n    \n    def __exit__(self, exc_type, exc_value, traceback):\n        \"\"\"\n        Called after rollout generation.\n        Offload models to CPU to free GPU memory.\n        \"\"\"\n        if self.is_asleep:\n            # Already offloaded\n            return\n        \n        # 1. Offload embedding model first (for EmbodiedHFRollout)\n        if self.offload_embedding:\n            self.rollout.embedding_model.offload_to_host()\n        \n        # 2. Offload actor model to CPU\n        if self.offload_param:\n            offload_fsdp_model_to_cpu(self.module)\n        \n        # 3. Clear cache\n        get_torch_device().empty_cache()\n        \n        self.is_asleep = True\n\n"
  },
  {
    "path": "siirl/engine/sharding_manager/fsdp_sglang.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Union\n\nfrom loguru import logger\nimport torch\nimport asyncio\nimport torch.distributed as dist\nfrom sglang.srt.entrypoints.engine import Engine\nfrom sglang.srt.model_executor.model_runner import LocalSerializedTensor\nfrom sglang.srt.utils import MultiprocessingSerializer\nfrom torch.distributed.device_mesh import DeviceMesh\nfrom torch.distributed.fsdp.api import FullStateDictConfig, ShardedStateDictConfig, StateDictType\nfrom torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP\nfrom torch.distributed.tensor import DTensor\n\nfrom tensordict import TensorDict\nfrom siirl.data_coordinator.protocol import all_gather_data_proto\nfrom siirl.utils.debug import log_gpu_memory_usage\nfrom siirl.utils.extras.device import get_device_id, get_device_name, get_torch_device\nfrom siirl.utils.model_utils.fsdp_utils import fsdp_version, load_fsdp_model_to_gpu, offload_fsdp_model_to_cpu\nfrom siirl.utils.model_utils.torch_functional import broadcast_dict_tensor, check_device_is_available\nfrom loguru import logger\nfrom siirl.engine.sharding_manager.base import BaseShardingManager\nfrom siirl.utils.extras.device import get_device_id, get_torch_device\nfrom siirl.utils.model_utils.model import convert_weight_keys\n# from vllm.distributed import parallel_state as sglang_ps\n\n\ndef _preprocess_tensor_for_update_weights(tensor: torch.Tensor):\n    if isinstance(tensor, DTensor):\n        return tensor.full_tensor()\n    return tensor\n\n\nclass MultiAgentFSDPSGLangShardingManager(BaseShardingManager):\n    @check_device_is_available()\n    def __init__(\n        self,\n        module: FSDP,\n        inference_engine:  Engine,\n        model_config,\n        device_mesh: torch.distributed.DeviceMesh, \n        rollout_config: dict[str, int],\n        full_params: bool = False,\n        offload_param: bool = False,\n        multi_stage_wake_up: bool = False\n    ):\n        self.module = module\n        self.inference_engine = inference_engine\n        self.model_config = model_config\n        self.device_mesh = device_mesh\n        
self.offload_param = offload_param\n\n        # Full params\n        self.full_params = full_params\n        if full_params and fsdp_version(self.module) == 1:\n            FSDP.set_state_dict_type(self.module, state_dict_type=StateDictType.FULL_STATE_DICT, state_dict_config=FullStateDictConfig())\n        elif fsdp_version(self.module) == 1:\n            FSDP.set_state_dict_type(\n                self.module,\n                state_dict_type=StateDictType.SHARDED_STATE_DICT,\n                state_dict_config=ShardedStateDictConfig(),\n            )\n\n        self.tp_size = self.device_mesh[\"infer_tp\"].size()\n        self.tp_rank = self.device_mesh[\"infer_tp\"].get_local_rank()\n\n        # Note that torch_random_states may be different on each dp rank\n        self.torch_random_states = torch.cuda.get_rng_state()\n        # get a random rng states\n        if self.device_mesh is not None:\n            gen_dp_rank = self.device_mesh[\"dp\"].get_local_rank()\n            torch.cuda.manual_seed(gen_dp_rank + 1000)  # make sure all tp ranks have the same random states\n            self.gen_random_states = torch.cuda.get_rng_state()\n            torch.cuda.set_rng_state(self.torch_random_states)\n        else:\n            self.gen_random_states = None\n\n    def __enter__(self):\n        torch.cuda.empty_cache()\n        log_gpu_memory_usage(\"Before state_dict() in sharding manager memory\", logger=logger)\n        if self.offload_param:\n            load_fsdp_model_to_gpu(self.module)\n        params = self.module.state_dict()\n        log_gpu_memory_usage(\"After state_dict() in sharding manager memory\", logger=logger)\n        device = torch.cuda.current_device()  # used when fsdp2 set cpu_offload_policy\n        params = {k: v.to(device, non_blocking=True) if fsdp_version(self.module) == 2 else v for k, v in params.items()}\n        # Copy, not share memory\n        loop = asyncio.get_event_loop()\n        loop.run_until_complete(self.update_weights(params))\n        log_gpu_memory_usage(\"After sync model weights in sharding manager\", logger=logger)\n\n        del params\n        if self.offload_param:\n            offload_fsdp_model_to_cpu(self.module)\n        torch.cuda.empty_cache()\n        log_gpu_memory_usage(\"After del state_dict and empty_cache in sharding manager\", logger=logger)\n\n        # important: need to manually set the random states of each tp to be identical.\n        if self.device_mesh is not None:\n            self.torch_random_states = torch.cuda.get_rng_state()\n            torch.cuda.set_rng_state(self.gen_random_states)\n\n    def __exit__(self, exc_type, exc_value, traceback):\n        log_gpu_memory_usage(\"Before SGLang offload in sharding manager\", logger=logger)\n        loop = asyncio.get_event_loop()\n        loop.run_until_complete(self.release_memory())\n        log_gpu_memory_usage(\"After SGLang offload in sharding manager\", logger=logger)\n\n        self.module.train()\n\n        # add empty cache after each compute\n        torch.cuda.empty_cache()\n\n        # restore random states\n        if self.device_mesh is not None:\n            self.gen_random_states = torch.cuda.get_rng_state()\n            torch.cuda.set_rng_state(self.torch_random_states)\n\n    async def update_weights(self, params):\n        if self.device_mesh[\"infer_tp\"].get_local_rank() == 0:\n            await self.inference_engine.resume_memory_occupation()\n\n        # Most naive implementation, can optimize a lot if it is bottleneck from sglang Engine weight 
update\n        named_tensors = [(k, v) for k, v in params.items()]\n        load_format = None\n        for tensor_index, (name, tensor) in enumerate(named_tensors):\n            serialized_tensor = MultiprocessingSerializer.serialize(_preprocess_tensor_for_update_weights(tensor))\n\n            if self.device_mesh[\"infer_tp\"].get_local_rank() == 0:\n                gathered_serialized_tensors = [None for _ in range(self.device_mesh[\"infer_tp\"].mesh.size()[0])]\n            else:\n                gathered_serialized_tensors = None\n            dist.gather_object(\n                obj=serialized_tensor,\n                object_gather_list=gathered_serialized_tensors,\n                dst=self.device_mesh[\"infer_tp\"].mesh.tolist()[0],\n                group=self.device_mesh[\"infer_tp\"].get_group(),\n            )\n\n            if self.device_mesh[\"infer_tp\"].get_local_rank() == 0:\n                await self.inference_engine.update_weights_from_tensor(\n                    named_tensors=[\n                        (\n                            name,\n                            LocalSerializedTensor(values=gathered_serialized_tensors),\n                        )\n                    ],\n                    load_format=load_format,\n                    flush_cache=tensor_index == len(named_tensors) - 1,\n                )\n\n    async def release_memory(self):\n        if self.device_mesh[\"infer_tp\"].get_local_rank() == 0:\n            await self.inference_engine.release_memory_occupation()\n\n    async def wake_up(self):\n        torch.cuda.empty_cache()\n        log_gpu_memory_usage(\"Before state_dict() in sharding manager memory\", logger=logger)\n        if self.offload_param:\n            load_fsdp_model_to_gpu(self.module)\n        params = self.module.state_dict()\n        log_gpu_memory_usage(\"After state_dict() in sharding manager memory\", logger=logger)\n        device = torch.cuda.current_device()  # used when fsdp2 set cpu_offload_policy\n        params = {k: v.to(device, non_blocking=True) if fsdp_version(self.module) == 2 else v for k, v in params.items()}\n        # Copy, not share memory\n        await self.update_weights(params)\n        log_gpu_memory_usage(\"After sync model weights in sharding manager\", logger=logger)\n\n        del params\n        if self.offload_param:\n            offload_fsdp_model_to_cpu(self.module)\n        torch.cuda.empty_cache()\n        log_gpu_memory_usage(\"After del state_dict and empty_cache in sharding manager\", logger=logger)\n\n        # important: need to manually set the random states of each tp to be identical.\n        if self.device_mesh is not None:\n            self.torch_random_states = torch.cuda.get_rng_state()\n            torch.cuda.set_rng_state(self.gen_random_states)\n\n    async def sleep(self):\n        log_gpu_memory_usage(\"Before SGLang offload in sharding manager\", logger=logger)\n        await self.release_memory()\n        log_gpu_memory_usage(\"After SGLang offload in sharding manager\", logger=logger)\n\n        self.module.train()\n\n        # add empty cache after each compute\n        torch.cuda.empty_cache()\n\n        # restore random states\n        if self.device_mesh is not None:\n            self.gen_random_states = torch.cuda.get_rng_state()\n            torch.cuda.set_rng_state(self.torch_random_states)\n\n    def preprocess_data(self, data: TensorDict) -> TensorDict:\n        \"\"\"All gather across tp group to make each rank has identical input.\"\"\"\n        if self.tp_size == 
1:\n            return data\n\n        # TODO: Current impl doesn't consider FSDP with torch micro-dp\n        group = self.device_mesh[\"infer_tp\"].get_group()\n\n        all_gather_data_proto(data=data, process_group=group)\n        return data\n\n    def postprocess_data(self, data: TensorDict) -> TensorDict:\n        \"\"\"Get chunk data of this tp rank since we do all gather in preprocess.\"\"\"\n        if self.tp_size == 1:\n            return data\n\n        return data.chunk(chunks=self.tp_size)[self.tp_rank]\n"
  },
  {
    "path": "siirl/engine/sharding_manager/fsdp_ulysses.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nContains a resharding manager that binds weights from FSDP zero3 to XPerfGPT\n\"\"\"\nfrom tensordict import TensorDict\nfrom torch.distributed.device_mesh import DeviceMesh\n\nfrom siirl.data_coordinator.protocol import all_gather_data_proto\nfrom siirl.utils.model_utils.ulysses import get_ulysses_sequence_parallel_group, set_ulysses_sequence_parallel_group\n\nfrom siirl.engine.sharding_manager.base import BaseShardingManager\n\n\nclass FSDPUlyssesShardingManager(BaseShardingManager):\n    \"\"\"\n    Sharding manager to support data resharding when using FSDP + Ulysses\n    \"\"\"\n\n    def __init__(self, device_mesh: DeviceMesh):\n        super().__init__()\n        self.device_mesh = device_mesh\n        self.seed_offset = 12345\n\n    def __enter__(self):\n        if self.device_mesh is not None:\n            # We have a global SP group\n            # so we have to change to use model-specific sp group\n            self.prev_sp_group = get_ulysses_sequence_parallel_group()\n            set_ulysses_sequence_parallel_group(self.device_mesh[\"sp\"].get_group())\n            # TODO: check how to set seed for each model\n\n    def __exit__(self, exc_type, exc_value, traceback):\n        # restore random states\n        if self.device_mesh is not None:\n            # revert to previous sp group\n            set_ulysses_sequence_parallel_group(self.prev_sp_group)\n            # TODO: check how to set seed for each model\n\n    def preprocess_data(self, data: TensorDict) -> TensorDict:\n        \"\"\"\n        AllGather data from sp region\n        This is because the data is first sharded along the FSDP dimension as we utilize the DP_COMPUTE\n        In Ulysses, we need to make sure the same data is used across a SP group\n        \"\"\"\n        if self.device_mesh is not None:\n            group = self.device_mesh[\"sp\"].get_group()\n\n            all_gather_data_proto(data=data, process_group=group)\n        return data\n\n    def postprocess_data(self, data: TensorDict) -> TensorDict:\n        \"\"\"\n        Split the data to follow FSDP partition\n        \"\"\"\n        if self.device_mesh is not None:\n            sp_size = self.device_mesh[\"sp\"].size()\n            sp_rank = self.device_mesh[\"sp\"].get_local_rank()\n            data = data.chunk(chunks=sp_size)[sp_rank]\n        return data\n"
  },
  {
    "path": "siirl/engine/sharding_manager/fsdp_vllm.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport inspect\nimport time\nfrom collections import OrderedDict\n\nfrom loguru import logger\nfrom torch.distributed.fsdp.api import FullStateDictConfig, ShardedStateDictConfig, StateDictType\nfrom torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP\n\ntry:\n    # for torch 2.5+\n    from torch.distributed.tensor import DTensor\nexcept ImportError:\n    from torch.distributed._tensor import DTensor\n\nfrom dataclasses import asdict\n\nfrom vllm import LLM\n\nfrom siirl.utils.debug import GPUMemoryLogger, log_gpu_memory_usage\nfrom siirl.utils.extras.device import get_device_id, get_device_name, get_torch_device\nfrom siirl.utils.model_utils.fsdp_utils import fsdp_version, layered_summon_lora_params, load_fsdp_model_to_gpu, offload_fsdp_model_to_cpu\nfrom siirl.utils.model_utils.torch_functional import check_device_is_available\nfrom siirl.utils.model_utils.vllm_utils import TensorLoRARequest, VLLMHijack, patch_vllm_moe_model_weight_loader\nfrom siirl.engine.sharding_manager.base import BaseShardingManager\n\n\nclass MultiAgentFSDPVLLMShardingManager(BaseShardingManager):\n    @check_device_is_available()\n    def __init__(self, module: FSDP, inference_engine: LLM, model_config, parallel_config: dict[str, int], full_params: bool = False, offload_param: bool = False, load_format: str = \"dummy_hf\", layered_summon: bool = True):\n        self.module = module\n        # For AsyncLLM, inference_engine and model_runner are defer initialized in vLLMAsyncRollout.load_model\n        self.inference_engine = inference_engine\n        # vLLM > v0.6.3\n        self.model_runner = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner if self.inference_engine else None\n\n        self.model_config = model_config\n        self.parallel_config = parallel_config\n        self.offload_param = offload_param\n        self.load_format = load_format\n        self.layered_summon = layered_summon\n\n        # Full params\n        self.full_params = full_params\n        if full_params and fsdp_version(self.module) == 1:\n            FSDP.set_state_dict_type(self.module, state_dict_type=StateDictType.FULL_STATE_DICT, state_dict_config=FullStateDictConfig())\n        elif fsdp_version(self.module) == 1:\n            FSDP.set_state_dict_type(\n                self.module,\n                state_dict_type=StateDictType.SHARDED_STATE_DICT,\n                state_dict_config=ShardedStateDictConfig(),\n            )\n\n        self.tp_size = self.parallel_config[\"rollout_parallel_size\"]\n        self.world_size = self.parallel_config[\"rollout_world_size\"]\n        self.rank = self.parallel_config[\"rollout_rank\"]\n        self.tp_rank = self.rank % self.tp_size\n\n        # Note that torch_random_states may be different on each dp rank\n        self.torch_random_states = get_torch_device().get_rng_state()\n        # get a 
random rng states\n        gen_dp_rank = self.rank // self.tp_size\n        get_torch_device().manual_seed(gen_dp_rank + 1000)  # make sure all tp ranks have the same random states\n        self.gen_random_states = get_torch_device().get_rng_state()\n        get_torch_device().set_rng_state(self.torch_random_states)\n\n        self.base_sync_done: bool = \"dummy\" not in load_format\n        # vllm >= 0.7.3\n        VLLMHijack.hijack()\n\n    @GPUMemoryLogger(role=\"fsdp vllm sharding_manager\", logger=logger)\n    def __enter__(self):\n        def __collect_lora_params() -> OrderedDict:\n            \"\"\"\n            collect lora params or full params if base model is not ready in vllm\n            work with if isinstance(self.module._fsdp_wrapped_module, PeftModel)\n            \"\"\"\n            from peft.utils.save_and_load import get_peft_model_state_dict\n\n            lora_params = OrderedDict()\n            peft_model = getattr(self.module, \"_fsdp_wrapped_module\", self.module)\n            if fsdp_version(self.module) > 0:\n                if self.layered_summon:\n                    if not self.base_sync_done:\n                        raise ValueError(\"To use layered_summon, you must make sure base-model is preloaded in vllm, e.g. let rollout.load_format=safetensors\")\n                    lora_params = layered_summon_lora_params(self.module)\n                else:\n                    with FSDP.summon_full_params(self.module, writeback=False):\n                        if self.base_sync_done:\n                            lora_params = get_peft_model_state_dict(peft_model)\n                            lora_params = {name: param.full_tensor().detach().cpu() if hasattr(param, \"full_tensor\") else param.detach().cpu() for name, param in lora_params.items()}\n                        else:\n                            model = peft_model.base_model.model\n                            orig_dev = \"cpu\" if \"cpu\" in next(model.parameters()).device else get_device_name()\n                            model = model.to(\"cpu\")\n                            for name, param in model.state_dict().items():\n                                if any(x in name for x in [\"_flat_param\", \"lora_\"]):\n                                    continue\n                                name = name.replace(\"_fsdp_wrapped_module.\", \"\").replace(\".base_layer\", \"\")\n                                lora_params[name] = param.full_tensor().detach().cpu() if hasattr(param, \"full_tensor\") else param.detach().cpu()\n                            model = model.to(orig_dev)\n                    get_torch_device().empty_cache()\n            else:\n                if self.base_sync_done:\n                    lora_params = get_peft_model_state_dict(peft_model)\n                else:\n                    model = peft_model.base_model.model\n                    orig_dev = \"cpu\" if \"cpu\" in next(model.parameters()).device else get_device_name()\n                    model = model.to(\"cpu\")\n                    for name, param in model.state_dict().items():\n                        if any(x in name for x in [\"_flat_param\", \"lora_\"]):\n                            continue\n                        name = name.replace(\"_fsdp_wrapped_module.\", \"\").replace(\".base_layer\", \"\")\n                        lora_params[name] = param.detach().cpu()\n                    model = model.to(orig_dev)\n            return lora_params\n\n        # NOTE: Basically, we only need `get_torch_device().empty_cache()` before 
vllm wake_up and\n        # after vllm sleep, since vllm has its own caching memory allocator CuMemAllocator.\n        # Out of vllm scope, we should avoid empty cache to let pytorch using caching memory\n        # to speed up memory allocations.\n        #\n        # pytorch: https://pytorch.org/docs/stable/notes/cuda.html#memory-management\n        # vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/device_allocator/cumem.py#L103\n        get_torch_device().empty_cache()\n\n        log_gpu_memory_usage(\"Before state_dict() in sharding manager memory\", logger=logger)\n        if self.offload_param:\n            load_fsdp_model_to_gpu(self.module)\n\n        peft_config = None\n        peft_model = getattr(self.module, \"_fsdp_wrapped_module\", self.module)\n        if hasattr(peft_model, \"peft_config\"):\n            peft_config = peft_model.peft_config.get(\"default\", None)\n            params = __collect_lora_params()\n        else:\n            params = self.module.state_dict()\n        log_gpu_memory_usage(\"After state_dict() in sharding manager memory\", logger=logger)\n\n        # Copy, not share memory\n        load_format = \"hf\" if self.full_params else \"dtensor\"\n\n        if \"tags\" in inspect.signature(self.inference_engine.wake_up).parameters:\n            self.inference_engine.wake_up(tags=[\"weights\"])\n        else:\n            self.inference_engine.wake_up()\n\n        # update model params\n        self.update_params(params, peft_config=peft_config)\n        log_gpu_memory_usage(\"After sync model weights in sharding manager\", logger=logger)\n        del params\n        if self.offload_param:\n            offload_fsdp_model_to_cpu(self.module)\n        get_torch_device().empty_cache()\n\n        if \"tags\" in inspect.signature(self.inference_engine.wake_up).parameters:\n            self.inference_engine.wake_up(tags=[\"kv_cache\"])\n\n        log_gpu_memory_usage(\"After del state_dict and empty_cache in sharding manager\", logger=logger)\n\n        # important: need to manually set the random states of each tp to be identical.\n        self.torch_random_states = get_torch_device().get_rng_state()\n        get_torch_device().set_rng_state(self.gen_random_states)\n\n    @GPUMemoryLogger(role=\"fsdp vllm sharding_manager\", logger=logger)\n    def __exit__(self, exc_type, exc_value, traceback):\n        # TODO(ZSL): check this\n        self.inference_engine.sleep(level=1)\n\n        self.module.train()\n\n        # add empty cache after each compute\n        get_torch_device().empty_cache()\n\n        # restore random states\n        self.gen_random_states = get_torch_device().get_rng_state()\n        get_torch_device().set_rng_state(self.torch_random_states)\n\n    def update_params(self, updated_params, peft_config=None):\n        model = self.model_runner.model\n        if peft_config:\n            if self.base_sync_done:\n                lora_int_id = int(time.time_ns() % 0x7FFFFFFF)\n                lora_reqest = TensorLoRARequest(\n                    lora_name=f\"{lora_int_id}\",\n                    lora_int_id=lora_int_id,\n                    lora_path=\"simon_lora_path\",\n                    peft_config=asdict(peft_config),\n                    lora_tensors=updated_params,\n                )\n                self.inference_engine.llm_engine.add_lora(lora_reqest)\n                logger.info(f\"vLLM load weights, loaded_params: {len(updated_params)}\")\n                return\n            else:\n\n                def 
replace_lora_wrapper(k):\n                    stacked_params = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"]\n                    if any([k.endswith(f\"{s}.weight\") for s in stacked_params]):\n                        return k.replace(\".weight\", \".base_layer.weight\")\n                    if any([k.endswith(f\"{s}.bias\") for s in stacked_params]):\n                        return k.replace(\".bias\", \".base_layer.bias\")\n                    return k\n\n                updated_params = {replace_lora_wrapper(k): v for k, v in updated_params.items()}\n\n        patch_vllm_moe_model_weight_loader(model)\n        device = get_device_id()  # used when fsdp2 set cpu_offload_policy\n        loaded_params = model.load_weights(((name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param) for name, param in updated_params.items()))\n\n        self.base_sync_done = True\n        import torch\n        logger.info(f\"{torch.distributed.get_rank()} vLLM load weights, loaded_params: {len(loaded_params) if loaded_params else -1}\")\n"
  },
  {
    "path": "siirl/engine/sharding_manager/megatron_sglang.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThis file contains a Megatron style Hybrid Engine that shares the weights of the actor with the inference engine.\n\"\"\"\n\nimport asyncio\n\nfrom sglang.srt.entrypoints.engine import Engine\nimport torch\nfrom torch import nn\nfrom torch.distributed.device_mesh import DeviceMesh\nfrom sglang.srt.utils import MultiprocessingSerializer\nfrom sglang.srt.model_executor.model_runner import LocalSerializedTensor\nimport torch.distributed as dist\nfrom torch.distributed.tensor import DTensor\n\nfrom siirl.data_coordinator.protocol import  all_gather_data_proto\nfrom siirl.utils.debug import GPUMemoryLogger, log_gpu_memory_usage\nfrom siirl.utils.extras.device import get_device_id, get_device_name, get_torch_device, set_expandable_segments\nfrom siirl.utils.megatron.megatron_utils import (\n    per_tensor_generator,\n    load_megatron_model_to_gpu,\n    offload_megatron_model_to_cpu,\n)\nfrom siirl.utils.memory_utils import aggressive_empty_cache\n\nfrom siirl.engine.sharding_manager.base import BaseShardingManager\nfrom loguru import logger\n\n\n\"\"\"\nMegatron Hybrid Engine:\n- During training, only the current pp stage holds the parameters\n- Before inference, broadcast the parameters of the current pp rank to all other pp ranks (all pp ranks holds all the parameters)\n- Bind the parameters to the inference engine\n- Do inference in tp. 
pp is treated as additional dp\n- After inference, all the parameters that doesn't belong to this pp rank is freed.\n\"\"\"\n\ndef _preprocess_tensor_for_update_weights(tensor: torch.Tensor):\n    if isinstance(tensor, DTensor):\n        return tensor.full_tensor()\n    return tensor\n\nclass MultiAgentMegatronSGLangShardingManager(BaseShardingManager):\n    def __init__(\n        self,\n        actor_module: nn.ModuleList,\n        inference_engine: Engine,\n        model_config,\n        rollout_config,\n        transformer_config,\n        layer_name_mapping,\n        weight_converter,\n        device_mesh: DeviceMesh | None = None,\n        offload_param: bool = False,\n        bridge=None,\n    ):\n        self.actor_module = actor_module\n        self.inference_engine = inference_engine\n        self.model_config = model_config\n        self.rollout_config = rollout_config\n        self.transformer_config = transformer_config\n        self.layer_name_mapping = layer_name_mapping\n        self.weight_converter = weight_converter\n        self.device_mesh = device_mesh\n        self.offload_param = offload_param\n        self.bridge = bridge\n\n        if self.device_mesh is not None:\n            self.infer_tp_size = self.device_mesh[\"infer_tp\"].size()\n        else:\n            self.infer_tp_size = self.inference_engine._tp_size\n\n        # Note that torch_random_states may be different on each dp rank\n        self.torch_random_states = get_torch_device().get_rng_state()\n        # get a random rng states\n        if self.device_mesh is not None:\n            gen_dp_rank = self.device_mesh[\"dp\"].get_local_rank()\n            get_torch_device().manual_seed(gen_dp_rank + 1000)  # make sure all tp ranks have the same random states\n            self.gen_random_states = get_torch_device().get_rng_state()\n            get_torch_device().set_rng_state(self.torch_random_states)\n        else:\n            self.gen_random_states = None\n\n    @GPUMemoryLogger(role=\"MultiAgentMegatronSGLangShardingManager enter\", logger=logger)\n    def __enter__(self):\n        loop = asyncio.get_event_loop()\n        loop.run_until_complete(self.wake_up())\n\n    @GPUMemoryLogger(role=\"MultiAgentMegatronSGLangShardingManager exit\", logger=logger)\n    def __exit__(self, exc_type, exc_value, traceback):\n        loop = asyncio.get_event_loop()\n        loop.run_until_complete(self.sleep())\n\n    async def update_weights(self, params):\n        \"\"\"\n        Update model weights using tensor buckets, similar to THUDM/slime's implementation.\n\n        Notes:\n          - For the best performance of `rebuild_cuda_tensor`, it is recommended to:\n              1. Enable `RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES`.\n              2. 
Manually set `CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7`\n            when using Tensor Parallelism (TP >= 8).\n          - See reference implementations in SLIME:\n            - Main logic: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L452\n            - runtime envs: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L39\n        \"\"\"\n        if self.device_mesh[\"infer_tp\"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:\n            await self.inference_engine.resume_memory_occupation()\n        named_tensors = params\n\n        load_format = None\n        for tensor_index, (name, tensor) in enumerate(named_tensors):\n            serialized_tensor = MultiprocessingSerializer.serialize(_preprocess_tensor_for_update_weights(tensor))\n\n            if self.device_mesh[\"infer_tp\"].get_local_rank() == 0:\n                gathered_serialized_tensors = [None for _ in range(self.device_mesh[\"infer_tp\"].mesh.size()[0])]\n            else:\n                gathered_serialized_tensors = None\n            dist.gather_object(\n                obj=serialized_tensor,\n                object_gather_list=gathered_serialized_tensors,\n                dst=self.device_mesh[\"infer_tp\"].mesh.tolist()[0],\n                group=self.device_mesh[\"infer_tp\"].get_group(),\n            )\n\n            if self.device_mesh[\"infer_tp\"].get_local_rank() == 0:\n                await self.inference_engine.update_weights_from_tensor(\n                    named_tensors=[\n                        (\n                            name,\n                            LocalSerializedTensor(values=gathered_serialized_tensors),\n                        )\n                    ],\n                    load_format=load_format,\n                    flush_cache=False,\n                )\n                \n            if self.device_mesh[\"infer_tp\"].get_local_rank() == 0:\n                await self.inference_engine.flush_cache()\n\n    async def release_memory(self):\n        if self.device_mesh[\"infer_tp\"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:\n            await self.inference_engine.release_memory_occupation()\n\n    @GPUMemoryLogger(role=\"MultiAgentMegatronSGLangShardingManager enter\", logger=logger)\n    async def wake_up(self):\n        aggressive_empty_cache(force_sync=True)\n\n        if self.offload_param:\n            load_megatron_model_to_gpu(self.actor_module, load_grad=False)\n        if self.bridge is not None:\n            per_tensor_param = self.bridge.export_weights(self.actor_module)\n        else:\n            per_tensor_param = per_tensor_generator(\n                self.actor_module,\n                self.model_config,\n                self.weight_converter,\n                self.transformer_config,\n                self.layer_name_mapping,\n            )\n\n        set_expandable_segments(False)\n\n        await self.update_weights(per_tensor_param)\n        if self.offload_param:\n            offload_megatron_model_to_cpu(self.actor_module)\n        aggressive_empty_cache(force_sync=True)\n        # important: need to manually set the random states of each tp to be identical.\n        if self.device_mesh is not None:\n            self.torch_random_states = get_torch_device().get_rng_state()\n            get_torch_device().set_rng_state(self.gen_random_states)\n\n    @GPUMemoryLogger(role=\"MultiAgentMegatronSGLangShardingManager exit\", 
logger=logger)\n    async def sleep(self):\n        if self.rollout_config.free_cache_engine:\n            log_gpu_memory_usage(\"Before SGLang offload in sharding manager\", logger=logger)\n            await self.release_memory()\n            log_gpu_memory_usage(\"After SGLang offload in sharding manager\", logger=logger)\n\n        for model in self.actor_module:\n            model.train()\n        # add empty cache after each compute\n        aggressive_empty_cache(force_sync=True)\n\n        set_expandable_segments(True)\n\n        # restore random states\n        if self.device_mesh is not None:\n            self.gen_random_states = get_torch_device().get_rng_state()\n            get_torch_device().set_rng_state(self.torch_random_states)\n"
  },
  {
    "path": "siirl/engine/sharding_manager/megatron_vllm.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThis file contains a Megatron style Hybrid Engine that shares the weights of the actor with the inference engine.\n\"\"\"\n\nimport inspect\nimport gc\n\nimport torch\nimport torch.distributed\nimport torch.distributed as dist\nfrom megatron.core import DistributedDataParallel as LocalDDP\nfrom megatron.core import parallel_state as mpu\nfrom megatron.core.transformer.module import Float16Module\nfrom torch import nn\nfrom torch.nn.parallel.distributed import DistributedDataParallel as torchDDP\n\nfrom siirl.models.mcore.weight_converter import McoreToHFWeightConverterBase\nfrom vllm import LLM\nfrom siirl.utils.debug import GPUMemoryLogger\nfrom siirl.utils.extras.device import get_torch_device, set_expandable_segments\nfrom siirl.utils.megatron.megatron_utils import (\n    get_model,\n    per_tensor_generator,\n    unwrap_model,\n    load_megatron_model_to_gpu,\n    offload_megatron_model_to_cpu,\n)\nfrom siirl.utils.megatron.memory_buffer import (\n    build_memory_buffer,\n    build_memory_reference_from_module,\n    get_weight_buffer_meta_from_module,\n)\nfrom siirl.utils.model_utils.torch_functional import check_device_is_available\nfrom siirl.utils.model_utils.vllm_utils import patch_vllm_moe_model_weight_loader\nfrom siirl.utils.memory_utils import aggressive_empty_cache\n\nfrom siirl.engine.sharding_manager.base import BaseShardingManager\n\nfrom loguru import logger\n\n\nclass AllGatherPPModel:\n    def __init__(self, model_provider, use_distributed_optimizer=True) -> None:\n        print(\n            \"[WARNING] This class is deprecated and will no longer be supported. 
\\\nConsider using the `MegatronPPOActor` class directly as a replacement.\"\n        )\n        self._pp_group = mpu.get_pipeline_model_parallel_group()\n        self._pp_rank = mpu.get_pipeline_model_parallel_rank()\n        self._pp_size = mpu.get_pipeline_model_parallel_world_size()\n        self._vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()\n        self._model_chunk_size = self._vpp_size or 1\n\n        # each one holds a list of model_chunks in this pp stage\n        self._pp_models = [None] * self.pp_size\n\n        rank_list = list(range(self.pp_size))\n        # make current rank the last one to initialize\n        rank_list[self.pp_rank], rank_list[-1] = rank_list[-1], rank_list[self.pp_rank]\n        self._this_rank_models = None\n\n        # store the parameter of each pp stage\n        self.memory_buffers = [None] * self.pp_size\n        for cur_pp_rank in rank_list:\n            print(\n                \"create pp model\",\n                f\"torch allocated {torch.cuda.memory_allocated() / 1e9:.4f} GB, reserved {torch.cuda.memory_reserved() / 1e9:.4f} GB\",\n            )\n            # since the last initialized rank is the current pp rank, after init, the pp rank is still correct\n            mpu.set_pipeline_model_parallel_rank(cur_pp_rank)\n            if cur_pp_rank != self.pp_rank:\n                models = get_model(model_provider, wrap_with_ddp=False, use_distributed_optimizer=False)\n                models = nn.ModuleList(models)\n                assert len(models) == self._model_chunk_size, f\"{len(models)} != {self._model_chunk_size}\"\n                self.pp_models[cur_pp_rank] = models\n            else:\n                # for regular model, we wrapped it with DDP\n                models = get_model(model_provider, wrap_with_ddp=True, use_distributed_optimizer=use_distributed_optimizer)\n                assert len(models) == self._model_chunk_size, f\"{len(models)} != {self._model_chunk_size}\"\n                self._this_rank_models = nn.ModuleList(models)\n                self.pp_models[cur_pp_rank] = nn.ModuleList(unwrap_model(models, (torchDDP, LocalDDP)))\n\n            self._build_param_buffer(cur_pp_rank)\n            self._build_param_references(cur_pp_rank, maintain_weight=cur_pp_rank == self.pp_rank)\n\n            # TODO: after binding to the memory buffer, we can load the checkpoint here\n            if cur_pp_rank != self.pp_rank:\n                for model in self.pp_models[cur_pp_rank]:\n                    model.eval()\n                self._offload_params_to_cpu(cur_pp_rank)\n\n    def _build_param_buffer(self, pp_rank):\n        \"\"\"Build the parameter buffer in each pp rank\"\"\"\n        if pp_rank == self._pp_rank:\n            from siirl.utils.megatron.memory_buffer import MemoryBuffer\n\n            # The code here is very hard-coded, based on the following assumptions:\n            # 1. `len(_this_rank_models) == 1`\n            # 2. `_this_rank_models[0]` is a instance of `DistributedDataParallel` and `use_distributed_optimizer=True`\n            # 3. 
Only bfloat16 data type is used in parameters\n            source = self._this_rank_models[0].buffers[0].param_data\n            self.memory_buffers[pp_rank] = {torch.bfloat16: MemoryBuffer(source.numel(), source.numel(), torch.bfloat16, source)}\n        else:\n            model = self.pp_models[pp_rank]\n            weight_buffer_meta = get_weight_buffer_meta_from_module(model)\n            self.memory_buffers[pp_rank] = build_memory_buffer(weight_buffer_meta)\n\n    def _build_param_references(self, pp_rank, maintain_weight=False):\n        if pp_rank == self._pp_rank:\n            return\n        model = self.pp_models[pp_rank]\n        build_memory_reference_from_module(model, self.memory_buffers[pp_rank], maintain_weight=maintain_weight)\n\n    def _load_params_to_cuda(self, pp_rank, to_empty=False):\n        assert pp_rank != self.pp_rank, f\"unexpected to load current pp rank [{pp_rank}] back to cuda\"\n        for buffer in self.memory_buffers[pp_rank].values():\n            if not to_empty:\n                buffer.data = buffer.data.to(torch.cuda.current_device(), non_blocking=True)\n            else:\n                buffer.data = torch.empty_like(buffer.data, device=\"cuda\")\n        # rebuild reference after loading to CUDA\n        self._build_param_references(pp_rank)\n\n    def _offload_params_to_cpu(self, pp_rank, to_empty=False):\n        assert pp_rank != self.pp_rank, f\"unexpected to offload current pp rank [{pp_rank}] to cpu\"\n        for buffer in self.memory_buffers[pp_rank].values():\n            if not to_empty:\n                # offload the whole memory buffer to CPU\n                buffer.data = buffer.data.to(\"cpu\", non_blocking=True)\n            else:\n                buffer.data = torch.empty_like(buffer.data, device=\"cpu\")\n        self._build_param_references(pp_rank)\n\n    def load_params_to_cuda(self, to_empty=False):\n        \"\"\"load all model params to cuda\"\"\"\n        for cur_pp_rank in range(self.pp_size):\n            if cur_pp_rank != self.pp_rank:\n                self._load_params_to_cuda(cur_pp_rank, to_empty=to_empty)\n\n    def allgather_params(self):\n        \"\"\"allgather params of all pp ranks. 
Return a list of handles\"\"\"\n        for cur_pp_rank in range(self.pp_size):\n            global_src = dist.get_global_rank(group=self.pp_group, group_rank=cur_pp_rank)\n\n            # NOTE(sgm): the async op may cause memory leakage of the memory_buffer/pp_models\n\n            for _, param in sorted(self.pp_models[cur_pp_rank].named_parameters()):\n                dist.broadcast(tensor=param.data, src=global_src, group=self.pp_group, async_op=False)\n\n    def forward(self, *inputs, **kwargs):\n        try:\n            prev_output = None\n            for cur_chunk_rank in range(self._model_chunk_size):\n                if self._vpp_size:\n                    mpu.set_virtual_pipeline_model_parallel_rank(cur_chunk_rank)\n\n                for cur_pp_rank in range(self.pp_size):\n                    mpu.set_pipeline_model_parallel_rank(cur_pp_rank)\n                    self.pp_models[cur_pp_rank][cur_chunk_rank].set_input_tensor(prev_output)\n                    ret = self.pp_models[cur_pp_rank][cur_chunk_rank](*inputs, **kwargs)\n                    self.pp_models[cur_pp_rank][cur_chunk_rank].set_input_tensor(None)\n                    prev_output = ret\n        finally:\n            if self._vpp_size:\n                mpu.set_virtual_pipeline_model_parallel_rank(0)\n            mpu.set_pipeline_model_parallel_rank(self.pp_rank)\n        return ret\n\n    def __call__(self, *inputs, **kwargs):\n        return self.forward(*inputs, **kwargs)\n\n    def eval(self):\n        for model in self.pp_models[self.pp_rank]:\n            model.eval()\n\n    def train(self):\n        for model in self.pp_models[self.pp_rank]:\n            model.train()\n\n    def offload_params_to_cpu(self, to_empty=False):\n        \"\"\"offload params of models that are not of current pp rank to cpu\"\"\"\n        for cur_pp_rank in range(self.pp_size):\n            if cur_pp_rank != self.pp_rank:\n                self._offload_params_to_cpu(cur_pp_rank, to_empty=to_empty)\n\n    def get_all_params(self):\n        \"\"\"Get all the parameters of the models in all pp ranks\n\n        Returns:\n            params: List[List[Dict[str, Tensor]]]: a list of parameters in all pp, where each is a list of dict\n                tensors of each model chunk\n\n        \"\"\"\n        params = []\n        for pp_rank in range(self.pp_size):\n            params.append([])\n            for model_chunk_idx in range(len(self.pp_models[pp_rank])):\n                params[pp_rank].append({})\n                pp_model = self.pp_models[pp_rank][model_chunk_idx]\n                pp_model = unwrap_model(pp_model, ((torchDDP, LocalDDP, Float16Module)))  # not use Float16Module\n                for name, param in pp_model.named_parameters():\n                    # NOTE(gh) workaround: should not get lora params for inference\n                    if \"lora\" in name:\n                        continue\n                    params[pp_rank][model_chunk_idx][name] = param\n\n        return params\n\n    def update_this_rank_models(self, new_models):\n        self._this_rank_models = new_models\n        self._pp_models[self.pp_rank] = unwrap_model(new_models, (torchDDP, LocalDDP))\n\n    @property\n    def this_rank_models(self):\n        return self._this_rank_models\n\n    @property\n    def pp_size(self):\n        return self._pp_size\n\n    @property\n    def pp_rank(self):\n        return self._pp_rank\n\n    @property\n    def pp_group(self):\n        return self._pp_group\n\n    @property\n    def pp_models(self):\n        return 
self._pp_models\n\n\n\"\"\"\nMegatron Hybrid Engine:\n- During training, only the current pp stage holds the parameters\n- Before inference, broadcast the parameters of the current pp rank \n   to all other pp ranks (all pp ranks holds all the parameters)\n- Bind the parameters to the inference engine\n- Do inference in tp. pp is treated as additional dp\n- After inference, all the parameters that doesn't belong to this pp rank is freed.\n\"\"\"\n\n\n# Micro Data parallel group. Micro data parallel group is additional dp group that origins from splitting training tp\n# into infer_tp and micro_tp. By default, we use order micro_dp - tp\n# NOTICE: in new version of vLLM, We need to all-gather all tp rank's model weights\n# For code reuse, we directly assign Megatron's TENSOR_MODEL_PARALLEL_GROUP to this\n_MICRO_DATA_PARALLEL_GROUP = None\n\n\nclass MultiAgentMegatronVLLMShardingManager(BaseShardingManager):\n    @check_device_is_available()\n    def __init__(\n        self,\n        actor_module: nn.ModuleList,\n        inference_engine: LLM,\n        model_config,\n        rollout_config,\n        transformer_config,\n        layer_name_mapping,\n        weight_converter: McoreToHFWeightConverterBase,\n        device_mesh,\n        module: AllGatherPPModel = None,\n        offload_param: bool = False,\n        bridge=None,\n    ):\n        from megatron.core import parallel_state as mpu\n\n        self.device_mesh = device_mesh\n        self.rollout_config = rollout_config\n\n        self.actor_module = actor_module\n        self.inference_engine = inference_engine\n        self.model_config = model_config\n        self.transformer_config = transformer_config\n        self.layer_name_mapping = layer_name_mapping\n        self.weight_converter = weight_converter\n        self.module = module\n        # initialize groups for vllm inference\n        self.rank = torch.distributed.get_rank()\n        self.world_size = torch.distributed.get_world_size()\n        self.infer_tp_size = self.device_mesh[\"infer_tp\"].size()\n\n        self.train_tp_size = mpu.get_tensor_model_parallel_world_size()\n        self.train_tp_rank = mpu.get_tensor_model_parallel_rank()\n        self.train_tp_group = mpu.get_tensor_model_parallel_group()\n        self.train_ep_size = mpu.get_expert_model_parallel_world_size()\n        self.train_ep_rank = mpu.get_expert_model_parallel_rank()\n        self.train_ep_group = mpu.get_expert_model_parallel_group()\n        self.train_etp_size = mpu.get_expert_tensor_parallel_world_size()\n        self.train_etp_rank = mpu.get_expert_tensor_parallel_rank()\n        self.train_etp_group = mpu.get_expert_tensor_parallel_group()\n        self.need_tp_reshard = self.train_tp_size != self.infer_tp_size\n        self.train_tp_larger = self.train_tp_size > self.infer_tp_size\n        self.offload_param = offload_param\n        self.bridge = bridge\n\n        self.torch_random_states = get_torch_device().get_rng_state()\n        if self.device_mesh is not None:\n            gen_dp_rank = self.device_mesh[\"dp\"].get_local_rank()\n            get_torch_device().manual_seed(gen_dp_rank + 1000)  # make sure all tp ranks have the same random states\n            self.gen_random_states = get_torch_device().get_rng_state()\n            get_torch_device().set_rng_state(self.torch_random_states)\n        else:\n            self.gen_random_states = None\n\n    @GPUMemoryLogger(role=\"megatron vllm sharding_manager\", logger=logger)\n    def __enter__(self):\n        
aggressive_empty_cache(force_sync=True)\n\n        if self.offload_param:\n            load_megatron_model_to_gpu(self.actor_module, load_grad=False)\n        \n        set_expandable_segments(False)\n\n        if self.rollout_config.free_cache_engine:\n            # vllm > 0.7.2\n            if \"tags\" in inspect.signature(self.inference_engine.wake_up).parameters:\n                self.inference_engine.wake_up(tags=[\"weights\"])\n            else:\n                self.inference_engine.wake_up()\n        if self.bridge is not None:\n            per_tensor_param = self.bridge.export_weights(self.actor_module)\n        else:\n            per_tensor_param = per_tensor_generator(\n                self.actor_module,\n                self.model_config,\n                self.weight_converter,\n                self.transformer_config,\n                self.layer_name_mapping,\n            )\n        \n        model = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model\n        patch_vllm_moe_model_weight_loader(model)\n        loaded_params = model.load_weights(per_tensor_param)\n        info = f\"vLLM load weights, loaded_params: {len(loaded_params)}\"\n        logger.info(info)\n\n        # (Ping Zhang) Explicitly delete the generator and collected params to free memory\n        del per_tensor_param\n        del loaded_params\n        gc.collect()\n\n        if self.offload_param:\n            offload_megatron_model_to_cpu(self.actor_module)\n        aggressive_empty_cache(force_sync=True)\n        \n        # (vermouth1992) We move wake up kv cache after we release model weights. Need refactor to make API cleaner\n\n        if (\n            self.rollout_config.free_cache_engine \n            and \"tags\" in inspect.signature(self.inference_engine.wake_up).parameters\n        ):\n            self.inference_engine.wake_up(tags=[\"kv_cache\"])\n        \n        # important: need to manually set the random states of each tp to be identical.\n        if self.device_mesh is not None:\n            self.torch_random_states = get_torch_device().get_rng_state()\n            get_torch_device().set_rng_state(self.gen_random_states)\n\n    @GPUMemoryLogger(role=\"megatron vllm sharding_manager\", logger=logger)\n    def __exit__(self, exc_type, exc_value, traceback):\n        if self.rollout_config.free_cache_engine:\n            self.inference_engine.sleep(level=2)\n        for model in self.actor_module:\n            model.train()\n\n        aggressive_empty_cache(force_sync=True)\n        set_expandable_segments(True)\n\n        # restore random states\n        if self.device_mesh is not None:\n            self.gen_random_states = get_torch_device().get_rng_state()\n            get_torch_device().set_rng_state(self.torch_random_states)"
  },
  {
    "path": "siirl/environment/embodied/__init__.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Embodied AI environment adapters.\"\"\"\n\nfrom .base import BaseVLAEnvironment as BaseEmbodiedEnvironment\nfrom .venv import SubprocVectorEnv\nfrom .adapters.libero import LIBEROAdapter\n\n__all__ = [\n    \"BaseEmbodiedEnvironment\",\n    \"LIBEROAdapter\",\n    \"SubprocVectorEnv\",\n]"
  },
  {
    "path": "siirl/environment/embodied/adapters/__init__.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Embodied environment adapters.\"\"\"\n\nfrom .libero import LIBEROAdapter\n\n__all__ = [\"LIBEROAdapter\"]"
  },
  {
    "path": "siirl/environment/embodied/adapters/libero.py",
    "content": "\n# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport asyncio\nimport random\n\nfrom functools import partial\nfrom loguru import logger\nfrom typing import Any, Dict, List, Tuple, Optional\n\nimport numpy as np\n\nfrom siirl.environment.embodied.base import BaseVLAEnvironment\nfrom siirl.environment.embodied.venv import SubprocVectorEnv\n\ntry:\n    from libero.libero import benchmark, get_libero_path\n    from libero.libero.envs import OffScreenRenderEnv\nexcept ImportError:\n    logger.error(\n        \"Error: LIBERO library not found. Please ensure it is installed correctly.\")\n    exit()\n\n\nclass LIBEROAdapter(BaseVLAEnvironment):\n    \"\"\"\n    An adapter for the LIBERO benchmark suite that wraps it in a vectorized,\n    asynchronous interface conforming to BaseVLAEnvironment.\n\n    This class manages a pool of LIBERO environments running in separate processes,\n    handling task sampling, state initialization, and batched stepping.\n\n    Note: While the interface is `async`, the underlying environment calls are\n    blocking. This implementation uses `asyncio.to_thread` to run blocking\n    I/O without blocking the event loop, making it compatible with async frameworks.\n    \"\"\"\n\n    def __init__(self,\n                 env_name: str,\n                 num_envs: int,\n                 max_steps: int,\n                 num_steps_wait: int = 10,\n                 model_family: str = \"openvla\",\n                 gpu_ids: List[int] = [0],\n                 seed: int = 0):\n        \"\"\"\n        Initializes the LIBERO Adapter.\n\n        Args:\n            env_name (str): The name of the LIBERO task suite to use (e.g., \"libero_10\").\n            num_envs (int): The number of parallel environments to run.\n            num_steps_wait (int): Number of dummy steps to wait for stabilization after reset.\n            model_family (str): The model family, affects action space format.\n            gpu_ids (List[int]): A list of GPU device IDs to distribute environments across.\n            seed (int): The base random seed.\n        \"\"\"\n        logger.info(\n            f\"[LIBEROAdapter] Initializing: suite={env_name}, num_envs={num_envs}, gpu_ids={gpu_ids}\")\n\n        self.task_suite_name = env_name\n        self.env_num = num_envs\n        self.seed = seed\n        self.max_steps = max_steps\n        self.num_steps_wait = num_steps_wait\n        self.model_family = model_family\n        self.gpu_ids = gpu_ids\n\n        self.env: SubprocVectorEnv = None\n        self.step_count = None  # Will be initialized as a fixed-size array on first reset\n\n        self.benchmark_dict = benchmark.get_benchmark_dict()\n        self.task_suite = self.benchmark_dict[self.task_suite_name]()\n\n    def _blocking_reset(self, task_ids: Optional[List[int]] = None, trial_ids: Optional[List[int]] = None) -> List[Dict[str, Any]]:\n        \"\"\"Synchronous implementation of the reset 
logic.\"\"\"\n\n        logger.debug(\n            f\"[LIBEROAdapter] Reset called: num_tasks={len(task_ids) if task_ids else self.env_num}\"\n        )\n\n        # Use provided task_ids or sample new ones\n        if task_ids is None:\n            logger.warning(\n                f\"[LIBEROAdapter] No task_ids provided, sampling {self.env_num} new tasks\")\n            task_ids = random.sample(\n                range(self.task_suite.n_tasks), self.env_num)\n        else:\n            assert len(\n                task_ids) <= self.env_num, \"Provided task_ids length must less or equal num_envs\"\n        \n        logger.info(f\"[LIBEROAdapter] Resetting {len(task_ids)} environments\")\n        \n        num_active_envs = len(task_ids)\n        active_env_ids = list(range(num_active_envs))\n\n        task_descriptions = []\n        initial_states_list = []\n        env_creators = []\n        resolution = 256\n\n        logger.debug(f\"[LIBEROAdapter] Loading {len(task_ids)} task configurations\")\n        for i, task_id in enumerate(task_ids):\n            task = self.task_suite.get_task(task_id)\n            task_descriptions.append(task.language)\n            task_initial_states = self.task_suite.get_task_init_states(task_id)\n            initial_states_list.append(task_initial_states)\n\n            assigned_gpu = self.gpu_ids[i % len(self.gpu_ids)]\n            env_creators.append(\n                partial(LIBEROAdapter._get_libero_env, task, assigned_gpu, resolution))\n\n        if self.env is None:\n            # First time reset: Always create self.env_num workers to ensure fixed worker pool\n            if len(env_creators) < self.env_num:\n                # Use the first task's env_creator as placeholder for remaining workers\n                placeholder_creator = env_creators[0]\n                env_creators_full = env_creators + [placeholder_creator] * (self.env_num - len(env_creators))\n                logger.info(\n                    f\"[LIBEROAdapter] Created worker pool: {self.env_num} workers total \"\n                    f\"({len(env_creators)} active + {self.env_num - len(env_creators)} placeholder)\"\n                )\n            else:\n                env_creators_full = env_creators\n                logger.info(f\"[LIBEROAdapter] Created worker pool: {self.env_num} workers total (all active)\")\n            \n            self.env = SubprocVectorEnv(env_creators_full)\n        else:\n            # Subsequent resets: Ensure we don't exceed available workers\n            assert len(task_ids) <= self.env_num, \\\n                f\"Cannot reset {len(task_ids)} environments when only {self.env_num} workers exist\"\n            \n            logger.info(f\"[LIBEROAdapter] Reinitializing {len(env_creators)} environments\")\n            self.env.reinit_envs(env_creators, id=active_env_ids)\n\n        logger.debug(f\"[LIBEROAdapter] Resetting {len(active_env_ids)} environments\")\n        # Reset only the active environments.\n        self.env.reset(id=active_env_ids)\n\n        initial_states_to_set = []\n        initial_state_ids = []\n        # Use provided trial_ids or sample new ones\n        if trial_ids is None:\n            logger.debug(f\"[LIBEROAdapter] No trial_ids provided, sampling randomly\")\n            trial_ids = [random.randint(\n                0, len(initial_states_list[i]) - 1) for i in range(len(task_ids))]\n        else:\n            assert len(\n                trial_ids) == len(task_ids), \"Provided trial_ids length must equal task_ids length\"\n\n 
       for i in range(len(trial_ids)):\n            state_id = trial_ids[i]\n            initial_state_ids.append(state_id)\n            initial_states_to_set.append(initial_states_list[i][state_id])\n\n        logger.debug(f\"[LIBEROAdapter] Setting initial states for {len(trial_ids)} environments\")\n        # Set initial state only for the active environments.\n        obs_np_list = self.env.set_init_state(initial_states_to_set, id=active_env_ids)\n\n        logger.debug(f\"[LIBEROAdapter] Running {self.num_steps_wait} warmup actions\")\n        for _ in range(self.num_steps_wait):\n            dummy_actions = [self._get_dummy_action()\n                            for _ in range(len(trial_ids))]\n            # Step only the active environments.\n            obs_np_list, _, _, _ = self.env.step(dummy_actions, id=active_env_ids)\n\n        # Initialize or reset step_count for the fixed-size worker pool\n        if self.step_count is None:\n            self.step_count = np.zeros(self.env_num, dtype=int)\n            logger.debug(f\"[LIBEROAdapter] Initialized step_count tracking (size={self.env_num})\")\n        else:\n            # Reset step count for the active environments only\n            self.step_count[active_env_ids] = 0\n\n        results = []\n        for i in range(len(task_ids)):\n            task_id = task_ids[i]\n            trial_id = initial_state_ids[i]\n            results.append({\n                'type': 'init',\n                'obs': obs_np_list[i],\n                \"task_description\": task_descriptions[i],\n                'valid_images': [obs_np_list[i][\"agentview_image\"][::-1, ::-1]],\n                'task_file_name': f\"{self.task_suite_name}_task_{task_id}_trial_{trial_id}\",\n                'active': True,\n                'complete': False,\n                'finish_step': 0\n            })\n        \n        logger.info(f\"[LIBEROAdapter] Reset completed, returned {len(results)} results\")\n        return results\n\n    async def reset(self, task_ids: Optional[List[int]] = None, trial_ids: Optional[List[int]] = None) -> List[Dict[str, Any]]:\n        \"\"\"Asynchronously resets all parallel environments.\"\"\"\n        return await asyncio.to_thread(self._blocking_reset, task_ids=task_ids, trial_ids=trial_ids)\n\n    def _blocking_step(self, action: Dict[str, Any]) -> List[Dict[str, Any]]:\n        \"\"\"Synchronous implementation of the step logic for an action chunk.\"\"\"\n\n        actions = action[\"actions\"]\n        active_indices_set = set(action[\"indices\"])\n        \n        batch_size = actions.shape[0]\n        results = [None] * batch_size\n        step_images = [None] * batch_size\n        \n        active_indices_list = sorted(list(active_indices_set))\n\n        for j in range(actions.shape[1]):\n            normalized_actions = []\n            active_indices_list = sorted(list(active_indices_set))\n            if len(active_indices_list) == 0:\n                break\n            for act_idx in active_indices_list:\n                try:\n                    single_action = actions[act_idx][j]\n                except Exception as e:\n                    logger.error(f\"[LIBEROAdapter] Failed to access action[{act_idx}][{j}]: {e}\")\n                    raise\n                normalized_action = self._normalize_gripper_action(\n                    single_action, binarize=True)\n                inverted_action = self._invert_gripper_action(normalized_action)\n                normalized_actions.append(inverted_action.tolist())\n\n          
  step_return = self.env.step(normalized_actions, active_indices_list)\n\n            if len(step_return) == 4:\n                obs, rew, dones, infos = step_return\n            else:  # new API\n                obs, rew, terminateds, truncateds, infos = step_return\n                dones = np.logical_or(terminateds, truncateds)\n\n            self.step_count[active_indices_list] += 1\n\n            for i in range(len(active_indices_list)):\n                act_idx = active_indices_list[i]\n                if step_images[act_idx] is None:\n                    step_images[act_idx] = []\n                step_images[act_idx].append(obs[i][\"agentview_image\"][::-1, ::-1])\n            \n                if dones[i] or self.step_count[act_idx] >= self.max_steps:\n                    # Only log when task is truly completed (not just max_steps reached)\n                    if dones[i]:\n                        logger.info(\n                            f\"[LIBEROAdapter] Environment {act_idx} completed successfully: \"\n                            f\"total_steps={self.step_count[act_idx]}\"\n                        )\n                    results[act_idx] = {\n                        'type': 'step',\n                        'obs': obs[i],\n                        'active': False,\n                        'complete': dones[i],\n                        'finish_step': self.step_count[act_idx],\n                        'valid_images': step_images[act_idx]\n                    }\n                    active_indices_set.remove(act_idx)\n\n        for i in range(len(active_indices_list)):\n            act_idx = active_indices_list[i]\n            if results[act_idx] is None:\n                results[act_idx] = {\n                    'type': 'step',\n                    'obs': obs[i],\n                    'active': not(dones[i] or self.step_count[act_idx] >= self.max_steps),\n                    'complete': dones[i],\n                    'finish_step': self.step_count[act_idx],\n                    'valid_images': step_images[act_idx]\n                }\n\n        return results\n\n    async def step(self, action: Dict[str, Any]) -> List[Dict[str, Any]]:\n        \"\"\"\n        Asynchronously steps all parallel environments.\n        Note: The return types are batched for vectorized operation.\n        \"\"\"\n        return await asyncio.to_thread(self._blocking_step, action)\n\n    def close(self):\n        \"\"\"Closes all environments and shuts down subprocesses.\"\"\"\n        logger.info(\"[LIBEROAdapter] Closing environment worker pool\")\n        if self.env is not None:\n            self.env.close()\n\n    @staticmethod\n    def _get_libero_env(task, gpu_id, resolution=256):\n        \"\"\"Initializes and returns the LIBERO environment.\"\"\"\n        task_bddl_file = os.path.join(get_libero_path(\n            \"bddl_files\"), task.problem_folder, task.bddl_file)\n        env_args = {\n            \"bddl_file_name\": task_bddl_file,\n            \"camera_heights\": resolution,\n            \"camera_widths\": resolution,\n            \"render_gpu_device_id\": gpu_id\n        }\n        env = OffScreenRenderEnv(**env_args)\n        # IMPORTANT: seed seems to affect object positions even when using fixed initial state\n        env.seed(0)\n        return env\n\n    def _get_dummy_action(self) -> List[float]:\n        \"\"\"Returns a neutral or no-op action for the specified model family.\"\"\"\n        return [0, 0, 0, 0, 0, 0, -1]\n\n    def _normalize_gripper_action(self, action: np.ndarray, 
binarize: bool = True) -> np.ndarray:\n        \"\"\"\n        Normalize gripper action from [0,1] to [-1,+1] range.\n        This is necessary for some environments because the dataset wrapper\n        standardizes gripper actions to [0,1]. Note that unlike the other action\n        dimensions, the gripper action is not normalized to [-1,+1] by default.\n        Normalization formula: y = 2 * (x - orig_low) / (orig_high - orig_low) - 1\n        Args:\n            action: Action array with gripper action in the last dimension\n            binarize: Whether to binarize gripper action to -1 or +1\n        Returns:\n            np.ndarray: Action array with normalized gripper action\n        \"\"\"\n        # Create a copy to avoid modifying the original\n        normalized_action = action.copy()\n        # Normalize the last action dimension to [-1,+1]\n        orig_low, orig_high = 0.0, 1.0\n        normalized_action[..., -1] = 2 * \\\n            (normalized_action[..., -1] - orig_low) / \\\n            (orig_high - orig_low) - 1\n        if binarize:\n            # Binarize to -1 or +1\n            normalized_action[..., -1] = np.sign(normalized_action[..., -1])\n        return normalized_action\n\n    def _invert_gripper_action(self, action: np.ndarray) -> np.ndarray:\n        \"\"\"\n        Flip the sign of the gripper action (last dimension of action vector).\n        This is necessary for environments where -1 = open, +1 = close, since\n        the RLDS dataloader aligns gripper actions such that 0 = close, 1 = open.\n        Args:\n            action: Action array with gripper action in the last dimension\n        Returns:\n            np.ndarray: Action array with inverted gripper action\n        \"\"\"\n        # Create a copy to avoid modifying the original\n        inverted_action = action.copy()\n        # Invert the gripper action\n        inverted_action[..., -1] = inverted_action[..., -1] * -1.0\n        return inverted_action\n"
  },
  {
    "path": "siirl/environment/embodied/base.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# -*- coding: utf-8 -*-\n\n\"\"\" \nThis module defines an abstract base class for a Vision-Language-Action (VLA) environment\n\"\"\"\n\nfrom abc import ABC, abstractmethod\nfrom typing import Any, Dict, Tuple\n\nclass BaseVLAEnvironment(ABC):\n    \"\"\"\n    Abstract Base Class for a Vision-Language-Action (VLA) environment.\n    It defines the standard asynchronous interface for resetting the environment\n    and stepping through it.\n    \"\"\"\n\n    @abstractmethod\n    async def reset(self) -> Dict[str, Any]:\n        \"\"\"\n        Resets the environment to an initial state.\n\n        Returns:\n            Dict[str, Any]: The initial multi-modal observation,\n                            e.g., {\"image\": np.array, \"text\": \"task prompt\"}.\n        \"\"\"\n        pass\n\n    @abstractmethod\n    async def step(self, action: Dict[str, Any]) -> Tuple[Dict[str, Any], float, bool, bool, Dict[str, Any]]:\n        \"\"\"\n        Runs one timestep of the environment's dynamics.\n\n        Args:\n            action (Dict[str, Any]): A dictionary containing the action to be executed.\n                                     For example, {\"continuous_action\": np.array([...])}.\n\n        Returns:\n            Tuple[Dict, float, bool, bool, Dict]: A tuple containing:\n                - observation (Dict): The next observation.\n                - reward (float): The reward received.\n                - terminated (bool): Whether the episode has ended.\n                - truncated (bool): Whether the episode was truncated.\n                - info (Dict): Auxiliary diagnostic information.\n        \"\"\"\n        pass\n"
  },
  {
    "path": "siirl/environment/embodied/venv.py",
    "content": "# Modified from https://github.com/Lifelong-Robot-Learning/LIBERO/blob/master/libero/libero/envs/venv.py\n\nimport cloudpickle\nimport ctypes\nimport gymnasium as gym\nimport numpy as np\nimport warnings\nimport time\n\nfrom abc import ABC, abstractmethod\nfrom collections import OrderedDict\nimport multiprocessing\nfrom multiprocessing import Array, Pipe, connection\nfrom multiprocessing.context import Process\nfrom typing import Any, Callable, List, Optional, Tuple, Union\n\n\n# Ref: https://github.com/Lifelong-Robot-Learning/LIBERO/issues/3\n# TODO: test it\nif multiprocessing.get_start_method(allow_none=True) != \"spawn\":  \n    multiprocessing.set_start_method(\"spawn\", force=True)\n\n\ngym_old_venv_step_type = Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]\ngym_new_venv_step_type = Tuple[\n    np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray\n]\nwarnings.simplefilter(\"once\", DeprecationWarning)\n_NP_TO_CT = {\n    np.bool_: ctypes.c_bool,\n    np.uint8: ctypes.c_uint8,\n    np.uint16: ctypes.c_uint16,\n    np.uint32: ctypes.c_uint32,\n    np.uint64: ctypes.c_uint64,\n    np.int8: ctypes.c_int8,\n    np.int16: ctypes.c_int16,\n    np.int32: ctypes.c_int32,\n    np.int64: ctypes.c_int64,\n    np.float32: ctypes.c_float,\n    np.float64: ctypes.c_double,\n}\n\n\ndef deprecation(msg: str) -> None:\n    \"\"\"Deprecation warning wrapper.\"\"\"\n    warnings.warn(msg, category=DeprecationWarning, stacklevel=2)\n\n\nclass CloudpickleWrapper(object):\n    \"\"\"A cloudpickle wrapper used in SubprocVectorEnv.\"\"\"\n\n    def __init__(self, data: Any) -> None:\n        self.data = data\n\n    def __getstate__(self) -> str:\n        return cloudpickle.dumps(self.data)\n\n    def __setstate__(self, data: str) -> None:\n        self.data = cloudpickle.loads(data)\n\n\nGYM_RESERVED_KEYS = [\n    \"metadata\",\n    \"reward_range\",\n    \"spec\",\n    \"action_space\",\n    \"observation_space\",\n]\n\n\n################################################################################\n#\n# Workers\n#\n################################################################################\n\n\nclass EnvWorker(ABC):\n    \"\"\"An abstract worker for an environment.\"\"\"\n\n    def __init__(self, env_fn: Callable[[], gym.Env]) -> None:\n        self._env_fn = env_fn\n        self.is_closed = False\n        self.result: Union[\n            gym_old_venv_step_type,\n            gym_new_venv_step_type,\n            Tuple[np.ndarray, dict],\n            np.ndarray,\n        ]\n        # self.action_space = self.get_env_attr(\"action_space\")  # noqa: B009\n        self.is_reset = False\n\n    @abstractmethod\n    def get_env_attr(self, key: str) -> Any:\n        pass\n\n    @abstractmethod\n    def set_env_attr(self, key: str, value: Any) -> None:\n        pass\n\n    def send(self, action: Optional[np.ndarray]) -> None:\n        \"\"\"Send action signal to low-level worker.\n\n        When action is None, it indicates sending \"reset\" signal; otherwise\n        it indicates \"step\" signal. The paired return value from \"recv\"\n        function is determined by such kind of different signal.\n        \"\"\"\n        if hasattr(self, \"send_action\"):\n            deprecation(\n                \"send_action will soon be deprecated. 
\"\n                \"Please use send and recv for your own EnvWorker.\"\n            )\n            if action is None:\n                self.is_reset = True\n                self.result = self.reset()\n            else:\n                self.is_reset = False\n                self.send_action(action)\n\n    def recv(\n        self,\n    ) -> Union[\n        gym_old_venv_step_type,\n        gym_new_venv_step_type,\n        Tuple[np.ndarray, dict],\n        np.ndarray,\n    ]:  # noqa:E125\n        \"\"\"Receive result from low-level worker.\n\n        If the last \"send\" function sends a NULL action, it only returns a\n        single observation; otherwise it returns a tuple of (obs, rew, done,\n        info) or (obs, rew, terminated, truncated, info), based on whether\n        the environment is using the old step API or the new one.\n        \"\"\"\n        if hasattr(self, \"get_result\"):\n            deprecation(\n                \"get_result will soon be deprecated. \"\n                \"Please use send and recv for your own EnvWorker.\"\n            )\n            if not self.is_reset:\n                self.result = self.get_result()\n        return self.result\n\n    @abstractmethod\n    def reset(self, **kwargs: Any) -> Union[np.ndarray, Tuple[np.ndarray, dict]]:\n        pass\n\n    def step(\n        self, action: np.ndarray\n    ) -> Union[gym_old_venv_step_type, gym_new_venv_step_type]:\n        \"\"\"Perform one timestep of the environment's dynamic.\n\n        \"send\" and \"recv\" are coupled in sync simulation, so users only call\n        \"step\" function. But they can be called separately in async\n        simulation, i.e. someone calls \"send\" first, and calls \"recv\" later.\n        \"\"\"\n        self.send(action)\n        return self.recv()  # type: ignore\n\n    @staticmethod\n    def wait(\n        workers: List[\"EnvWorker\"], wait_num: int, timeout: Optional[float] = None\n    ) -> List[\"EnvWorker\"]:\n        \"\"\"Given a list of workers, return those ready ones.\"\"\"\n        raise NotImplementedError\n\n    def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:\n        # return self.action_space.seed(seed)  # issue 299\n        pass\n\n    @abstractmethod\n    def render(self, **kwargs: Any) -> Any:\n        \"\"\"Render the environment.\"\"\"\n        pass\n\n    @abstractmethod\n    def close_env(self) -> None:\n        pass\n\n    def close(self) -> None:\n        if self.is_closed:\n            return None\n        self.is_closed = True\n        self.close_env()\n\n\nclass ShArray:\n    \"\"\"Wrapper of multiprocessing Array.\"\"\"\n\n    def __init__(self, dtype: np.generic, shape: Tuple[int]) -> None:\n        self.arr = Array(_NP_TO_CT[dtype.type], int(np.prod(shape)))  # type: ignore\n        self.dtype = dtype\n        self.shape = shape\n\n    def save(self, ndarray: np.ndarray) -> None:\n        assert isinstance(ndarray, np.ndarray)\n        dst = self.arr.get_obj()\n        dst_np = np.frombuffer(dst, dtype=self.dtype).reshape(\n            self.shape\n        )  # type: ignore\n        np.copyto(dst_np, ndarray)\n\n    def get(self) -> np.ndarray:\n        obj = self.arr.get_obj()\n        return np.frombuffer(obj, dtype=self.dtype).reshape(self.shape)  # type: ignore\n\n\ndef _setup_buf(space: gym.Space) -> Union[dict, tuple, ShArray]:\n    if isinstance(space, gym.spaces.Dict):\n        assert isinstance(space.spaces, OrderedDict)\n        return {k: _setup_buf(v) for k, v in space.spaces.items()}\n    elif isinstance(space, 
gym.spaces.Tuple):\n        assert isinstance(space.spaces, tuple)\n        return tuple([_setup_buf(t) for t in space.spaces])\n    else:\n        return ShArray(space.dtype, space.shape)  # type: ignore\n\n\ndef _worker(\n    parent: connection.Connection,\n    p: connection.Connection,\n    env_fn_wrapper: CloudpickleWrapper,\n    obs_bufs: Optional[Union[dict, tuple, ShArray]] = None,\n) -> None:\n    def _encode_obs(\n        obs: Union[dict, tuple, np.ndarray], buffer: Union[dict, tuple, ShArray]\n    ) -> None:\n        if isinstance(obs, np.ndarray) and isinstance(buffer, ShArray):\n            buffer.save(obs)\n        elif isinstance(obs, tuple) and isinstance(buffer, tuple):\n            for o, b in zip(obs, buffer):\n                _encode_obs(o, b)\n        elif isinstance(obs, dict) and isinstance(buffer, dict):\n            for k in obs.keys():\n                _encode_obs(obs[k], buffer[k])\n        return None\n\n    parent.close()\n    env = env_fn_wrapper.data()\n    try:\n        while True:\n            try:\n                cmd, data = p.recv()\n            except EOFError:  # the pipe has been closed\n                p.close()\n                break\n            if cmd == \"step\":\n                env_return = env.step(data)\n                if obs_bufs is not None:\n                    _encode_obs(env_return[0], obs_bufs)\n                    env_return = (None, *env_return[1:])\n                p.send(env_return)\n            elif cmd == \"reinit_env\":\n                env.close()\n                env = data.data() # data is a CloudpickleWrapper with the new env_fn\n                p.send(True)\n            elif cmd == \"reset\":\n                retval = env.reset(**data)\n                reset_returns_info = (\n                    isinstance(retval, (tuple, list))\n                    and len(retval) == 2\n                    and isinstance(retval[1], dict)\n                )\n                if reset_returns_info:\n                    obs, info = retval\n                else:\n                    obs = retval\n                if obs_bufs is not None:\n                    _encode_obs(obs, obs_bufs)\n                    obs = None\n                if reset_returns_info:\n                    p.send((obs, info))\n                else:\n                    p.send(obs)\n            elif cmd == \"close\":\n                p.send(env.close())\n                p.close()\n                break\n            elif cmd == \"render\":\n                p.send(env.render(**data) if hasattr(env, \"render\") else None)\n            elif cmd == \"seed\":\n                if hasattr(env, \"seed\"):\n                    p.send(env.seed(data))\n                else:\n                    env.reset(seed=data)\n                    p.send(None)\n            elif cmd == \"getattr\":\n                p.send(getattr(env, data) if hasattr(env, data) else None)\n            elif cmd == \"setattr\":\n                setattr(env.unwrapped, data[\"key\"], data[\"value\"])\n            elif cmd == \"check_success\":\n                p.send(env.check_success())\n            elif cmd == \"get_segmentation_of_interest\":\n                p.send(env.get_segmentation_of_interest(data))\n            elif cmd == \"get_sim_state\":\n                p.send(env.get_sim_state())\n            elif cmd == \"set_init_state\":\n                obs = env.set_init_state(data)\n                p.send(obs)\n            else:\n                p.close()\n                raise NotImplementedError\n    except 
KeyboardInterrupt:\n        p.close()\n\n\nclass DummyEnvWorker(EnvWorker):\n    \"\"\"Dummy worker used in sequential vector environments.\"\"\"\n\n    def __init__(self, env_fn: Callable[[], gym.Env]) -> None:\n        self.env = env_fn()\n        super().__init__(env_fn)\n\n    def get_env_attr(self, key: str) -> Any:\n        return getattr(self.env, key)\n\n    def set_env_attr(self, key: str, value: Any) -> None:\n        setattr(self.env.unwrapped, key, value)\n\n    def reset(self, **kwargs: Any) -> Union[np.ndarray, Tuple[np.ndarray, dict]]:\n        if \"seed\" in kwargs:\n            super().seed(kwargs[\"seed\"])\n        return self.env.reset(**kwargs)\n\n    @staticmethod\n    def wait(  # type: ignore\n        workers: List[\"DummyEnvWorker\"], wait_num: int, timeout: Optional[float] = None\n    ) -> List[\"DummyEnvWorker\"]:\n        # Sequential EnvWorker objects are always ready\n        return workers\n\n    def send(self, action: Optional[np.ndarray], **kwargs: Any) -> None:\n        if action is None:\n            self.result = self.env.reset(**kwargs)\n        else:\n            self.result = self.env.step(action)  # type: ignore\n\n    def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:\n        super().seed(seed)\n        try:\n            return self.env.seed(seed)  # type: ignore\n        except (AttributeError, NotImplementedError):\n            self.env.reset(seed=seed)\n            return [seed]  # type: ignore\n\n    def render(self, **kwargs: Any) -> Any:\n        return self.env.render(**kwargs)\n\n    def close_env(self) -> None:\n        self.env.close()\n\n    def check_success(self):\n        return self.env.check_success()\n\n    def get_segmentation_of_interest(self, segmentation_image):\n        return self.env.get_segmentation_of_interest(segmentation_image)\n\n    def get_sim_state(self):\n        return self.env.get_sim_state()\n\n    def set_init_state(self, init_state):\n        return self.env.set_init_state(init_state)\n\n\nclass SubprocEnvWorker(EnvWorker):\n    \"\"\"Subprocess worker used in SubprocVectorEnv and ShmemVectorEnv.\"\"\"\n\n    def __init__(\n        self, env_fn: Callable[[], gym.Env], share_memory: bool = False\n    ) -> None:\n        self.parent_remote, self.child_remote = Pipe()\n        self.share_memory = share_memory\n        self.buffer: Optional[Union[dict, tuple, ShArray]] = None\n        if self.share_memory:\n            dummy = env_fn()\n            obs_space = dummy.observation_space\n            dummy.close()\n            del dummy\n            self.buffer = _setup_buf(obs_space)\n        args = (\n            self.parent_remote,\n            self.child_remote,\n            CloudpickleWrapper(env_fn),\n            self.buffer,\n        )\n        self.process = Process(target=_worker, args=args, daemon=True)\n        self.process.start()\n        self.child_remote.close()\n        super().__init__(env_fn)\n\n    def get_env_attr(self, key: str) -> Any:\n        self.parent_remote.send([\"getattr\", key])\n        return self.parent_remote.recv()\n\n    def set_env_attr(self, key: str, value: Any) -> None:\n        self.parent_remote.send([\"setattr\", {\"key\": key, \"value\": value}])\n\n    def _decode_obs(self) -> Union[dict, tuple, np.ndarray]:\n        def decode_obs(\n            buffer: Optional[Union[dict, tuple, ShArray]]\n        ) -> Union[dict, tuple, np.ndarray]:\n            if isinstance(buffer, ShArray):\n                return buffer.get()\n            elif isinstance(buffer, 
tuple):\n                return tuple([decode_obs(b) for b in buffer])\n            elif isinstance(buffer, dict):\n                return {k: decode_obs(v) for k, v in buffer.items()}\n            else:\n                raise NotImplementedError\n\n        return decode_obs(self.buffer)\n\n    @staticmethod\n    def wait(  # type: ignore\n        workers: List[\"SubprocEnvWorker\"],\n        wait_num: int,\n        timeout: Optional[float] = None,\n    ) -> List[\"SubprocEnvWorker\"]:\n        remain_conns = conns = [x.parent_remote for x in workers]\n        ready_conns: List[connection.Connection] = []\n        remain_time, t1 = timeout, time.time()\n        while len(remain_conns) > 0 and len(ready_conns) < wait_num:\n            if timeout:\n                remain_time = timeout - (time.time() - t1)\n                if remain_time <= 0:\n                    break\n            # connection.wait hangs if the list is empty\n            new_ready_conns = connection.wait(remain_conns, timeout=remain_time)\n            ready_conns.extend(new_ready_conns)  # type: ignore\n            remain_conns = [conn for conn in remain_conns if conn not in ready_conns]\n        return [workers[conns.index(con)] for con in ready_conns]\n        \n    def send_reinit(self, env_fn: Callable[[], gym.Env]) -> None:\n        \"\"\"\n        Sends a reinit command without waiting for a response.\n        This allows for parallel initialization.\n        \"\"\"\n        self.parent_remote.send((\"reinit_env\", CloudpickleWrapper(env_fn)))\n\n    def recv_reinit(self) -> bool:\n        \"\"\"\n        Waits for and receives the confirmation from a reinit command.\n        \"\"\"\n        try:\n            return self.parent_remote.recv()\n        except (BrokenPipeError, EOFError):\n            return False\n\n    def send(self, action: Optional[np.ndarray], **kwargs: Any) -> None:\n        if action is None:\n            if \"seed\" in kwargs:\n                super().seed(kwargs[\"seed\"])\n            self.parent_remote.send([\"reset\", kwargs])\n        else:\n            self.parent_remote.send([\"step\", action])\n\n    def recv(\n        self,\n    ) -> Union[\n        gym_old_venv_step_type,\n        gym_new_venv_step_type,\n        Tuple[np.ndarray, dict],\n        np.ndarray,\n    ]:  # noqa:E125\n        result = self.parent_remote.recv()\n        if isinstance(result, tuple):\n            if len(result) == 2:\n                obs, info = result\n                if self.share_memory:\n                    obs = self._decode_obs()\n                return obs, info\n            obs = result[0]\n            if self.share_memory:\n                obs = self._decode_obs()\n            return (obs, *result[1:])  # type: ignore\n        else:\n            obs = result\n            if self.share_memory:\n                obs = self._decode_obs()\n            return obs\n\n    def reset(self, **kwargs: Any) -> Union[np.ndarray, Tuple[np.ndarray, dict]]:\n        if \"seed\" in kwargs:\n            super().seed(kwargs[\"seed\"])\n        self.parent_remote.send([\"reset\", kwargs])\n\n        result = self.parent_remote.recv()\n        if isinstance(result, tuple):\n            obs, info = result\n            if self.share_memory:\n                obs = self._decode_obs()\n            return obs, info\n        else:\n            obs = result\n            if self.share_memory:\n                obs = self._decode_obs()\n            return obs\n\n    def seed(self, seed: Optional[int] = None) -> 
Optional[List[int]]:\n        super().seed(seed)\n        self.parent_remote.send([\"seed\", seed])\n        ret = self.parent_remote.recv()\n        return ret\n\n    def render(self, **kwargs: Any) -> Any:\n        self.parent_remote.send([\"render\", kwargs])\n        return self.parent_remote.recv()\n\n    def close_env(self) -> None:\n        try:\n            self.parent_remote.send([\"close\", None])\n            # mp may be deleted so it may raise AttributeError\n            self.parent_remote.recv()\n            self.process.join()\n        except (BrokenPipeError, EOFError, AttributeError):\n            pass\n        # ensure the subproc is terminated\n        self.process.terminate()\n\n    def check_success(self):\n        self.parent_remote.send([\"check_success\", None])\n        return self.parent_remote.recv()\n\n    def get_segmentation_of_interest(self, segmentation_image):\n        self.parent_remote.send([\"get_segmentation_of_interest\", segmentation_image])\n        return self.parent_remote.recv()\n\n    def get_sim_state(self):\n        self.parent_remote.send([\"get_sim_state\", None])\n        return self.parent_remote.recv()\n\n    def set_init_state(self, init_state):\n        self.parent_remote.send([\"set_init_state\", init_state])\n        obs = self.parent_remote.recv()\n        if self.share_memory:\n            obs = self._decode_obs()\n        return obs\n\n\n################################################################################\n#\n# VecEnvs\n#\n################################################################################\n\n\nclass BaseVectorEnv(object):\n    \"\"\"Base class for vectorized environments.\n\n    Usage:\n    ::\n\n        env_num = 8\n        envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(env_num)])\n        assert len(envs) == env_num\n\n    It accepts a list of environment generators. In other words, an environment\n    generator ``efn`` of a specific task means that ``efn()`` returns the\n    environment of the given task, for example, ``gym.make(task)``.\n\n    All of the VectorEnv must inherit :class:`~tianshou.env.BaseVectorEnv`.\n    Here are some other usages:\n    ::\n\n        envs.seed(2)  # which is equal to the next line\n        envs.seed([2, 3, 4, 5, 6, 7, 8, 9])  # set specific seed for each env\n        obs = envs.reset()  # reset all environments\n        obs = envs.reset([0, 5, 7])  # reset 3 specific environments\n        obs, rew, done, info = envs.step([1] * 8)  # step synchronously\n        envs.render()  # render all environments\n        envs.close()  # close all environments\n\n    .. warning::\n\n        If you use your own environment, please make sure the ``seed`` method\n        is set up properly, e.g.,\n        ::\n\n            def seed(self, seed):\n                np.random.seed(seed)\n\n        Otherwise, the outputs of these envs may be the same with each other.\n\n    :param env_fns: a list of callable envs, ``env_fns[i]()`` generates the i-th env.\n    :param worker_fn: a callable worker, ``worker_fn(env_fns[i])`` generates a\n        worker which contains the i-th env.\n    :param int wait_num: use in asynchronous simulation if the time cost of\n        ``env.step`` varies with time and synchronously waiting for all\n        environments to finish a step is time-wasting. In that case, we can\n        return when ``wait_num`` environments finish a step and keep on\n        simulation in these environments. 
If ``None``, asynchronous simulation\n        is disabled; else, ``1 <= wait_num <= env_num``.\n    :param float timeout: use in asynchronous simulation same as above, in each\n        vectorized step it only deal with those environments spending time\n        within ``timeout`` seconds.\n    \"\"\"\n\n    def __init__(\n        self,\n        env_fns: List[Callable[[], gym.Env]],\n        worker_fn: Callable[[Callable[[], gym.Env]], EnvWorker],\n        wait_num: Optional[int] = None,\n        timeout: Optional[float] = None,\n    ) -> None:\n        self._env_fns = env_fns\n        # A VectorEnv contains a pool of EnvWorkers, which corresponds to\n        # interact with the given envs (one worker <-> one env).\n        self.workers = [worker_fn(fn) for fn in env_fns]\n        self.worker_class = type(self.workers[0])\n        assert issubclass(self.worker_class, EnvWorker)\n        assert all([isinstance(w, self.worker_class) for w in self.workers])\n\n        self.env_num = len(env_fns)\n        self.wait_num = wait_num or len(env_fns)\n        assert (\n            1 <= self.wait_num <= len(env_fns)\n        ), f\"wait_num should be in [1, {len(env_fns)}], but got {wait_num}\"\n        self.timeout = timeout\n        assert (\n            self.timeout is None or self.timeout > 0\n        ), f\"timeout is {timeout}, it should be positive if provided!\"\n        self.is_async = self.wait_num != len(env_fns) or timeout is not None\n        self.waiting_conn: List[EnvWorker] = []\n        # environments in self.ready_id is actually ready\n        # but environments in self.waiting_id are just waiting when checked,\n        # and they may be ready now, but this is not known until we check it\n        # in the step() function\n        self.waiting_id: List[int] = []\n        # all environments are ready in the beginning\n        self.ready_id = list(range(self.env_num))\n        self.is_closed = False\n\n    def _assert_is_not_closed(self) -> None:\n        assert (\n            not self.is_closed\n        ), f\"Methods of {self.__class__.__name__} cannot be called after close.\"\n\n    def __len__(self) -> int:\n        \"\"\"Return len(self), which is the number of environments.\"\"\"\n        return self.env_num\n\n    def __getattribute__(self, key: str) -> Any:\n        \"\"\"Switch the attribute getter depending on the key.\n\n        Any class who inherits ``gym.Env`` will inherit some attributes, like\n        ``action_space``. However, we would like the attribute lookup to go straight\n        into the worker (in fact, this vector env's action_space is always None).\n        \"\"\"\n        if key in GYM_RESERVED_KEYS:  # reserved keys in gym.Env\n            return self.get_env_attr(key)\n        else:\n            return super().__getattribute__(key)\n\n    def get_env_attr(\n        self,\n        key: str,\n        id: Optional[Union[int, List[int], np.ndarray]] = None,\n    ) -> List[Any]:\n        \"\"\"Get an attribute from the underlying environments.\n\n        If id is an int, retrieve the attribute denoted by key from the environment\n        underlying the worker at index id. The result is returned as a list with one\n        element. Otherwise, retrieve the attribute for all workers at indices id and\n        return a list that is ordered correspondingly to id.\n\n        :param str key: The key of the desired attribute.\n        :param id: Indice(s) of the desired worker(s). 
Default to None for all env_id.\n\n        :return list: The list of environment attributes.\n        \"\"\"\n        self._assert_is_not_closed()\n        id = self._wrap_id(id)\n        if self.is_async:\n            self._assert_id(id)\n\n        return [self.workers[j].get_env_attr(key) for j in id]\n\n    def set_env_attr(\n        self,\n        key: str,\n        value: Any,\n        id: Optional[Union[int, List[int], np.ndarray]] = None,\n    ) -> None:\n        \"\"\"Set an attribute in the underlying environments.\n\n        If id is an int, set the attribute denoted by key from the environment\n        underlying the worker at index id to value.\n        Otherwise, set the attribute for all workers at indices id.\n\n        :param str key: The key of the desired attribute.\n        :param Any value: The new value of the attribute.\n        :param id: Indice(s) of the desired worker(s). Default to None for all env_id.\n        \"\"\"\n        self._assert_is_not_closed()\n        id = self._wrap_id(id)\n        if self.is_async:\n            self._assert_id(id)\n        for j in id:\n            self.workers[j].set_env_attr(key, value)\n\n    def _wrap_id(\n        self,\n        id: Optional[Union[int, List[int], np.ndarray]] = None,\n    ) -> Union[List[int], np.ndarray]:\n        if id is None:\n            return list(range(self.env_num))\n        return [id] if np.isscalar(id) else id  # type: ignore\n\n    def _assert_id(self, id: Union[List[int], np.ndarray]) -> None:\n        for i in id:\n            assert (\n                i not in self.waiting_id\n            ), f\"Cannot interact with environment {i} which is stepping now.\"\n            assert (\n                i in self.ready_id\n            ), f\"Can only interact with ready environments {self.ready_id}.\"\n\n    def reset(\n        self,\n        id: Optional[Union[int, List[int], np.ndarray]] = None,\n        **kwargs: Any,\n    ) -> Union[np.ndarray, Tuple[np.ndarray, Union[dict, List[dict]]]]:\n        \"\"\"Reset the state of some envs and return initial observations.\n\n        If id is None, reset the state of all the environments and return\n        initial observations, otherwise reset the specific environments with\n        the given id, either an int or a list.\n        \"\"\"\n        self._assert_is_not_closed()\n        id = self._wrap_id(id)\n        if self.is_async:\n            self._assert_id(id)\n\n        # send(None) == reset() in worker\n        for i in id:\n            self.workers[i].send(None, **kwargs)\n        ret_list = [self.workers[i].recv() for i in id]\n\n        reset_returns_info = (\n            isinstance(ret_list[0], (tuple, list))\n            and len(ret_list[0]) == 2\n            and isinstance(ret_list[0][1], dict)\n        )\n        if reset_returns_info:\n            obs_list = [r[0] for r in ret_list]\n        else:\n            obs_list = ret_list\n\n        if isinstance(obs_list[0], tuple):\n            raise TypeError(\n                \"Tuple observation space is not supported. 
\"\n                \"Please change it to array or dict space.\"\n            )\n        try:\n            obs = np.stack(obs_list)\n        except ValueError:  # different len(obs)\n            obs = np.array(obs_list, dtype=object)\n\n        if reset_returns_info:\n            infos = [r[1] for r in ret_list]\n            return obs, infos  # type: ignore\n        else:\n            return obs\n\n    def step(\n        self,\n        action: np.ndarray,\n        id: Optional[Union[int, List[int], np.ndarray]] = None,\n    ) -> Union[gym_old_venv_step_type, gym_new_venv_step_type]:\n        \"\"\"Run one timestep of some environments' dynamics.\n\n        If id is None, run one timestep of all the environments’ dynamics;\n        otherwise run one timestep for some environments with given id, either\n        an int or a list. When the end of episode is reached, you are\n        responsible for calling reset(id) to reset this environment’s state.\n\n        Accept a batch of action and return a tuple (batch_obs, batch_rew,\n        batch_done, batch_info) in numpy format.\n\n        :param numpy.ndarray action: a batch of action provided by the agent.\n\n        :return: A tuple consisting of either:\n\n            * ``obs`` a numpy.ndarray, the agent's observation of current environments\n            * ``rew`` a numpy.ndarray, the amount of rewards returned after \\\n                previous actions\n            * ``done`` a numpy.ndarray, whether these episodes have ended, in \\\n                which case further step() calls will return undefined results\n            * ``info`` a numpy.ndarray, contains auxiliary diagnostic \\\n                information (helpful for debugging, and sometimes learning)\n\n            or:\n\n            * ``obs`` a numpy.ndarray, the agent's observation of current environments\n            * ``rew`` a numpy.ndarray, the amount of rewards returned after \\\n                previous actions\n            * ``terminated`` a numpy.ndarray, whether these episodes have been \\\n                terminated\n            * ``truncated`` a numpy.ndarray, whether these episodes have been truncated\n            * ``info`` a numpy.ndarray, contains auxiliary diagnostic \\\n                information (helpful for debugging, and sometimes learning)\n\n            The case distinction is made based on whether the underlying environment\n            uses the old step API (first case) or the new step API (second case).\n\n        For the async simulation:\n\n        Provide the given action to the environments. The action sequence\n        should correspond to the ``id`` argument, and the ``id`` argument\n        should be a subset of the ``env_id`` in the last returned ``info``\n        (initially they are env_ids of all the environments). 
If action is\n        None, fetch unfinished step() calls instead.\n        \"\"\"\n        self._assert_is_not_closed()\n        id = self._wrap_id(id)\n        if not self.is_async:\n            assert len(action) == len(id)\n            for i, j in enumerate(id):\n                self.workers[j].send(action[i])\n            result = []\n            for j in id:\n                env_return = self.workers[j].recv()\n                env_return[-1][\"env_id\"] = j\n                result.append(env_return)\n        else:\n            if action is not None:\n                self._assert_id(id)\n                assert len(action) == len(id)\n                for act, env_id in zip(action, id):\n                    self.workers[env_id].send(act)\n                    self.waiting_conn.append(self.workers[env_id])\n                    self.waiting_id.append(env_id)\n                self.ready_id = [x for x in self.ready_id if x not in id]\n            ready_conns: List[EnvWorker] = []\n            while not ready_conns:\n                ready_conns = self.worker_class.wait(\n                    self.waiting_conn, self.wait_num, self.timeout\n                )\n            result = []\n            for conn in ready_conns:\n                waiting_index = self.waiting_conn.index(conn)\n                self.waiting_conn.pop(waiting_index)\n                env_id = self.waiting_id.pop(waiting_index)\n                # env_return can be (obs, reward, done, info) or\n                # (obs, reward, terminated, truncated, info)\n                env_return = conn.recv()\n                env_return[-1][\"env_id\"] = env_id  # Add `env_id` to info\n                result.append(env_return)\n                self.ready_id.append(env_id)\n        return_lists = tuple(zip(*result))\n        obs_list = return_lists[0]\n        try:\n            obs_stack = np.stack(obs_list)\n        except ValueError:  # different len(obs)\n            obs_stack = np.array(obs_list, dtype=object)\n        other_stacks = map(np.stack, return_lists[1:])\n        return (obs_stack, *other_stacks)  # type: ignore\n\n    def seed(\n        self,\n        seed: Optional[Union[int, List[int]]] = None,\n    ) -> List[Optional[List[int]]]:\n        \"\"\"Set the seed for all environments.\n\n        Accept ``None``, an int (which will extend ``i`` to\n        ``[i, i + 1, i + 2, ...]``) or a list.\n\n        :return: The list of seeds used in this env's random number generators.\n            The first value in the list should be the \"main\" seed, or the value\n            which a reproducer pass to \"seed\".\n        \"\"\"\n        self._assert_is_not_closed()\n        seed_list: Union[List[None], List[int]]\n        if seed is None:\n            seed_list = [seed] * self.env_num\n        elif isinstance(seed, int):\n            seed_list = [seed + i for i in range(self.env_num)]\n        else:\n            seed_list = seed\n        return [w.seed(s) for w, s in zip(self.workers, seed_list)]\n\n    def render(self, **kwargs: Any) -> List[Any]:\n        \"\"\"Render all of the environments.\"\"\"\n        self._assert_is_not_closed()\n        if self.is_async and len(self.waiting_id) > 0:\n            raise RuntimeError(\n                f\"Environments {self.waiting_id} are still stepping, cannot \"\n                \"render them now.\"\n            )\n        return [w.render(**kwargs) for w in self.workers]\n\n    def close(self) -> None:\n        \"\"\"Close all of the environments.\n\n        This function will be called only 
once (if not, it will be called during\n        garbage collection). This way, ``close`` of all workers can be assured.\n        \"\"\"\n        self._assert_is_not_closed()\n        for w in self.workers:\n            w.close()\n        self.is_closed = True\n\n\nclass DummyVectorEnv(BaseVectorEnv):\n    \"\"\"Dummy vectorized environment wrapper, implemented with a for-loop.\n\n    .. seealso::\n\n        Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.\n    \"\"\"\n\n    def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:\n        super().__init__(env_fns, DummyEnvWorker, **kwargs)\n\n    def check_success(self):\n        return [w.check_success() for w in self.workers]\n\n    def get_segmentation_of_interest(self, segmentation_images):\n        return [\n            w.get_segmentation_of_interest(img)\n            for w, img in zip(self.workers, segmentation_images)\n        ]\n\n    def get_sim_state(self):\n        return [w.get_sim_state() for w in self.workers]\n\n    def set_init_state(\n        self,\n        init_state: Optional[Union[int, List[int], np.ndarray]] = None,\n        id: Optional[Union[int, List[int], np.ndarray]] = None,\n        **kwargs: Any,\n    ) -> Union[np.ndarray, Tuple[np.ndarray, Union[dict, List[dict]]]]:\n        \"\"\"Reset the state of some envs and return initial observations.\n        If id is None, reset the state of all the environments and return\n        initial observations, otherwise reset the specific environments with\n        the given id, either an int or a list.\n        \"\"\"\n        self._assert_is_not_closed()\n        id = self._wrap_id(id)\n        if self.is_async:\n            self._assert_id(id)\n\n        # send(None) == reset() in worker\n        obs_list = []\n        for j, i in enumerate(id):\n            obs = self.workers[i].set_init_state(init_state[j])\n            obs_list.append(obs)\n        obs = np.stack(obs_list)\n        return obs\n\n\nclass SubprocVectorEnv(BaseVectorEnv):\n    \"\"\"Vectorized environment wrapper based on subprocess.\n\n    .. 
seealso::\n\n        Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.\n    \"\"\"\n\n    def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:\n        def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:\n            return SubprocEnvWorker(fn, share_memory=False)\n\n        super().__init__(env_fns, worker_fn, **kwargs)\n\n    def reinit_envs(self, env_fns: List[Callable[[], gym.Env]], id: Optional[Union[int, List[int], np.ndarray]] = None) -> None:\n        \"\"\"\n        Re-initializes the environments in parallel for a subset of workers.\n        \"\"\"\n        self._assert_is_not_closed()\n\n        target_ids = self._wrap_id(id)\n        \n        if len(env_fns) != len(target_ids):\n            raise ValueError(f\"Number of env_fns ({len(env_fns)}) must match the number of target ids ({len(target_ids)}).\")\n\n        target_workers = [self.workers[i] for i in target_ids]\n\n        # Send reinit command to all target workers without waiting.\n        for i, (worker, env_fn) in enumerate(zip(target_workers, env_fns)):\n            if not isinstance(worker, SubprocEnvWorker):\n                raise TypeError(\n                    f\"reinit_envs is only supported for SubprocEnvWorker, but worker {target_ids[i]} is type {type(worker).__name__}.\"\n                )\n            worker.send_reinit(env_fn)\n\n        # Wait for all target workers to confirm reinitialization.\n        results = [worker.recv_reinit() for worker in target_workers]\n        if not all(results):\n            failed_indices = [target_ids[i] for i, success in enumerate(results) if not success]\n            raise RuntimeError(f\"Worker processes {failed_indices} failed to re-initialize environment.\")\n\n    def check_success(self):\n        return [w.check_success() for w in self.workers]\n\n    def get_segmentation_of_interest(self, segmentation_images):\n        return [\n            w.get_segmentation_of_interest(img)\n            for w, img in zip(self.workers, segmentation_images)\n        ]\n\n    def get_sim_state(self):\n        return [w.get_sim_state() for w in self.workers]\n\n    def set_init_state(\n        self,\n        init_state: Optional[Union[int, List[int], np.ndarray]] = None,\n        id: Optional[Union[int, List[int], np.ndarray]] = None,\n        **kwargs: Any,\n    ) -> Union[np.ndarray, Tuple[np.ndarray, Union[dict, List[dict]]]]:\n        \"\"\"Reset the state of some envs and return initial observations.\n        If id is None, reset the state of all the environments and return\n        initial observations, otherwise reset the specific environments with\n        the given id, either an int or a list.\n        \"\"\"\n        self._assert_is_not_closed()\n        id = self._wrap_id(id)\n        if self.is_async:\n            self._assert_id(id)\n\n        # send(None) == reset() in worker\n        obs_list = []\n        for j, i in enumerate(id):\n            obs = self.workers[i].set_init_state(init_state[j])\n            obs_list.append(obs)\n        obs = np.stack(obs_list)\n        return obs\n\n"
  },
  {
    "path": "siirl/execution/dag/__init__.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .node import Node, NodeRole, NodeStatus, NodeType\nfrom .task_graph import TaskGraph\nfrom .task_loader import discover_and_split_parallel_paths\nfrom .pipeline import Pipeline\nfrom .builtin_pipelines import grpo_pipeline, ppo_pipeline, dapo_pipeline\n\n__all__ = [\n    \"Node\",\n    \"NodeStatus\",\n    \"NodeType\",\n    \"NodeRole\",\n    \"TaskGraph\",\n    \"discover_and_split_parallel_paths\",\n    \"Pipeline\",\n    \"grpo_pipeline\",\n    \"ppo_pipeline\",\n    \"dapo_pipeline\"\n]\n"
  },
  {
    "path": "siirl/execution/dag/builtin_pipelines.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nBuilt-in pipeline definitions for standard RL algorithms.\n\nAll function paths are explicitly visible, making it easy to understand\nwhat each node in the pipeline executes.\n\"\"\"\n\nfrom siirl.execution.dag.pipeline import Pipeline\nfrom siirl.execution.dag.task_graph import TaskGraph\nfrom siirl.execution.dag.node import NodeType, NodeRole\n\n\ndef grpo_pipeline() -> TaskGraph:\n    \"\"\"\n    Standard GRPO (Group Relative Policy Optimization) pipeline.\n\n    Workflow:\n        1. rollout_actor: Generate sequences using the policy model\n        2. function_reward: Compute rewards for generated sequences\n        3. calculate_advantages: Calculate advantage estimates\n        4. actor_old_log_prob: Compute log probabilities with old policy (forward only)\n        5. reference_log_prob: Compute log probabilities with reference model\n        6. actor_train: Train the actor model\n\n    Returns:\n        TaskGraph: A validated task graph ready for execution\n    \"\"\"\n    pipeline = Pipeline(\"grpo_training_pipeline\", \"Standard GRPO workflow\")\n\n    # All function paths are explicitly visible!\n    pipeline.add_node(\n        \"rollout_actor\",\n        func=\"siirl.dag_worker.dagworker:DAGWorker.generate\",\n        deps=[],\n        node_type=NodeType.MODEL_INFERENCE,\n        node_role=NodeRole.ROLLOUT\n    ).add_node(\n        \"function_reward\",\n        func=\"siirl.dag_worker.dagworker:DAGWorker.compute_reward\",\n        deps=[\"rollout_actor\"],\n        node_type=NodeType.COMPUTE,\n        node_role=NodeRole.REWARD\n    ).add_node(\n        \"calculate_advantages\",\n        func=\"siirl.dag_worker.dagworker:DAGWorker.compute_advantage\",\n        deps=[\"function_reward\"],\n        node_type=NodeType.COMPUTE,\n        node_role=NodeRole.ADVANTAGE\n    ).add_node(\n        \"actor_old_log_prob\",\n        func=\"siirl.dag_worker.dagworker:DAGWorker.compute_old_log_prob\",\n        deps=[\"calculate_advantages\"],\n        node_type=NodeType.MODEL_TRAIN,\n        node_role=NodeRole.ACTOR,\n        only_forward_compute=True\n    ).add_node(\n        \"reference_log_prob\",\n        func=\"siirl.dag_worker.dagworker:DAGWorker.compute_ref_log_prob\",\n        deps=[\"actor_old_log_prob\"],\n        node_type=NodeType.MODEL_TRAIN,\n        node_role=NodeRole.REFERENCE\n    ).add_node(\n        \"actor_train\",\n        func=\"siirl.dag_worker.dagworker:DAGWorker.train_actor\",\n        deps=[\"reference_log_prob\"],\n        node_type=NodeType.MODEL_TRAIN,\n        node_role=NodeRole.ACTOR\n    )\n\n    return pipeline.build()\n\n\ndef ppo_pipeline() -> TaskGraph:\n    \"\"\"\n    Standard PPO (Proximal Policy Optimization) pipeline.\n\n    Workflow:\n        1. rollout_actor: Generate sequences using the policy model\n        2. function_reward: Compute rewards for generated sequences\n        3. 
compute_value: Compute value function estimates (forward only)\n        4. calculate_advantages: Calculate GAE (Generalized Advantage Estimation)\n        5. actor_old_log_prob: Compute log probabilities with old policy (forward only)\n        6. reference_log_prob: Compute log probabilities with reference model\n        7. actor_train: Train the actor model\n        8. critic_train: Train the critic (value) model\n\n    Returns:\n        TaskGraph: A validated task graph ready for execution\n    \"\"\"\n    pipeline = Pipeline(\"ppo_training_pipeline\", \"Standard PPO workflow\")\n\n    pipeline.add_node(\n        \"rollout_actor\",\n        func=\"siirl.dag_worker.dagworker:DAGWorker.generate\",\n        deps=[],\n        node_type=NodeType.MODEL_INFERENCE,\n        node_role=NodeRole.ROLLOUT\n    ).add_node(\n        \"function_reward\",\n        func=\"siirl.dag_worker.dagworker:DAGWorker.compute_reward\",\n        deps=[\"rollout_actor\"],\n        node_type=NodeType.COMPUTE,\n        node_role=NodeRole.REWARD\n    ).add_node(\n        \"compute_value\",\n        func=\"siirl.dag_worker.dagworker:DAGWorker.compute_value\",\n        deps=[\"function_reward\"],\n        node_type=NodeType.MODEL_TRAIN,\n        node_role=NodeRole.CRITIC,\n        only_forward_compute=True\n    ).add_node(\n        \"calculate_advantages\",\n        func=\"siirl.dag_worker.dagworker:DAGWorker.compute_advantage\",\n        deps=[\"compute_value\"],\n        node_type=NodeType.COMPUTE,\n        node_role=NodeRole.ADVANTAGE\n    ).add_node(\n        \"actor_old_log_prob\",\n        func=\"siirl.dag_worker.dagworker:DAGWorker.compute_old_log_prob\",\n        deps=[\"calculate_advantages\"],\n        node_type=NodeType.MODEL_TRAIN,\n        node_role=NodeRole.ACTOR,\n        only_forward_compute=True\n    ).add_node(\n        \"reference_log_prob\",\n        func=\"siirl.dag_worker.dagworker:DAGWorker.compute_ref_log_prob\",\n        deps=[\"actor_old_log_prob\"],\n        node_type=NodeType.MODEL_TRAIN,\n        node_role=NodeRole.REFERENCE\n    ).add_node(\n        \"actor_train\",\n        func=\"siirl.dag_worker.dagworker:DAGWorker.train_actor\",\n        deps=[\"reference_log_prob\"],\n        node_type=NodeType.MODEL_TRAIN,\n        node_role=NodeRole.ACTOR\n    ).add_node(\n        \"critic_train\",\n        func=\"siirl.dag_worker.dagworker:DAGWorker.train_critic\",\n        deps=[\"actor_train\"],\n        node_type=NodeType.MODEL_TRAIN,\n        node_role=NodeRole.CRITIC\n    )\n\n    return pipeline.build()\n\n\ndef dapo_pipeline() -> TaskGraph:\n    \"\"\"\n    DAPO (Data-Augmented Policy Optimization) pipeline.\n\n    DAPO is a variant of GRPO with dynamic sampling filtering based on metric variance.\n    The key difference is that after computing rewards, we filter out trajectory groups\n    with zero variance (all correct or all incorrect) as they provide no learning signal.\n\n    Workflow:\n        1. rollout_actor: Generate sequences using the policy model\n        2. function_reward: Compute rewards for generated sequences\n        3. dynamic_sampling: DAPO-specific filtering based on metric variance\n        4. calculate_advantages: Calculate advantage estimates\n        5. actor_old_log_prob: Compute log probabilities with old policy (forward only)\n        6. reference_log_prob: Compute log probabilities with reference model\n        7. 
actor_train: Train the actor model\n\n    Returns:\n        TaskGraph: A validated task graph ready for execution\n    \"\"\"\n    pipeline = Pipeline(\"dapo_training_pipeline\", \"DAPO workflow\")\n\n    pipeline.add_node(\n        \"rollout_actor\",\n        func=\"siirl.dag_worker.dagworker:DAGWorker.generate\",\n        deps=[],\n        node_type=NodeType.MODEL_INFERENCE,\n        node_role=NodeRole.ROLLOUT\n    ).add_node(\n        \"function_reward\",\n        func=\"siirl.dag_worker.dagworker:DAGWorker.compute_reward\",\n        deps=[\"rollout_actor\"],\n        node_type=NodeType.COMPUTE,\n        node_role=NodeRole.REWARD\n    ).add_node(\n        \"dynamic_sampling\",\n        func=\"siirl.user_interface.filter_interface.dapo.dynamic_sampling\",\n        deps=[\"function_reward\"],\n        node_type=NodeType.COMPUTE,\n        node_role=NodeRole.DYNAMIC_SAMPLING\n    ).add_node(\n        \"calculate_advantages\",\n        func=\"siirl.dag_worker.dagworker:DAGWorker.compute_advantage\",\n        deps=[\"dynamic_sampling\"],\n        node_type=NodeType.COMPUTE,\n        node_role=NodeRole.ADVANTAGE\n    ).add_node(\n        \"actor_old_log_prob\",\n        func=\"siirl.dag_worker.dagworker:DAGWorker.compute_old_log_prob\",\n        deps=[\"calculate_advantages\"],\n        node_type=NodeType.MODEL_TRAIN,\n        node_role=NodeRole.ACTOR,\n        only_forward_compute=True\n    ).add_node(\n        \"reference_log_prob\",\n        func=\"siirl.dag_worker.dagworker:DAGWorker.compute_ref_log_prob\",\n        deps=[\"actor_old_log_prob\"],\n        node_type=NodeType.MODEL_TRAIN,\n        node_role=NodeRole.REFERENCE\n    ).add_node(\n        \"actor_train\",\n        func=\"siirl.dag_worker.dagworker:DAGWorker.train_actor\",\n        deps=[\"reference_log_prob\"],\n        node_type=NodeType.MODEL_TRAIN,\n        node_role=NodeRole.ACTOR\n    )\n\n    return pipeline.build()\n\ndef embodied_srpo_pipeline() -> TaskGraph:\n    \"\"\"\n    Embodied AI GRPO training pipeline with data filtering and VJEPA-based reward computation.\n\n    Workflow:\n        1. rollout_actor: Environment rollout with embodied AI agent\n        2. embodied_sampling: Data verification and filtering\n        3. data_rebalance: Data rebalancing across workers (after filtering)\n        4. compute_reward: VJEPA-based reward computation\n        5. calculate_advantages: Calculate advantages (GRPO group-based)\n        6. actor_old_log_prob: Compute old actor log probabilities (forward only)\n        7. reference_log_prob: Compute reference model log probabilities\n        8. 
actor_train: Actor training with GRPO\n\n    Returns:\n        TaskGraph: A validated task graph ready for execution\n    \"\"\"\n    pipeline = Pipeline(\n        \"embodied_grpo_training_pipeline\",\n        \"Embodied AI GRPO training workflow with data filtering and VJEPA-based reward computation.\"\n    )\n\n    pipeline.add_node(\n        \"rollout_actor\",\n        func=\"siirl.dag_worker.dagworker:DAGWorker.generate\",\n        deps=[],\n        node_type=NodeType.MODEL_INFERENCE,\n        node_role=NodeRole.ROLLOUT\n    ).add_node(\n        \"embodied_sampling\",\n        func=\"siirl.user_interface.filter_interface.embodied.embodied_local_rank_sampling\",\n        deps=[\"rollout_actor\"], \n        node_type=NodeType.COMPUTE,\n        node_role=NodeRole.DYNAMIC_SAMPLING      \n    ).add_node(\n        \"compute_reward\",\n        func=\"siirl.dag_worker.dagworker:DAGWorker.compute_reward\",\n        deps=[\"embodied_sampling\"],\n        node_type=NodeType.COMPUTE,\n        node_role=NodeRole.REWARD\n    ).add_node(\n        \"calculate_advantages\",\n        func=\"siirl.dag_worker.dagworker:DAGWorker.compute_advantage\",\n        deps=[\"compute_reward\"],\n        node_type=NodeType.COMPUTE,\n        node_role=NodeRole.ADVANTAGE\n    ).add_node(\n        \"actor_old_log_prob\",\n        func=\"siirl.dag_worker.dagworker:DAGWorker.compute_old_log_prob\",\n        deps=[\"calculate_advantages\"],\n        node_type=NodeType.MODEL_TRAIN,\n        node_role=NodeRole.ACTOR,\n        only_forward_compute=True\n    ).add_node(\n        \"reference_log_prob\",\n        func=\"siirl.dag_worker.dagworker:DAGWorker.compute_ref_log_prob\",\n        deps=[\"actor_old_log_prob\"],\n        node_type=NodeType.MODEL_TRAIN,\n        node_role=NodeRole.REFERENCE\n    ).add_node(\n        \"actor_train\",\n        func=\"siirl.dag_worker.dagworker:DAGWorker.train_actor\",  \n        deps=[\"reference_log_prob\"],\n        node_type=NodeType.MODEL_TRAIN,\n        node_role=NodeRole.ACTOR\n    )\n\n    return pipeline.build()\n\n__all__ = [\"grpo_pipeline\", \"ppo_pipeline\", \"dapo_pipeline\", \"embodied_srpo_pipeline\"]\n"
  },
  {
    "path": "siirl/execution/dag/config_loader.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nfrom typing import Any, Dict, List, Optional\n\nimport yaml\nfrom loguru import logger\n\nfrom siirl.execution.dag.node import Node, NodeRole, NodeType\nfrom siirl.execution.dag.task_graph import TaskGraph\n\n\nclass Ref:\n    \"\"\"\n    Represents the !Ref tag in YAML, used to reference other configuration paths.\n    \"\"\"\n\n    def __init__(self, path):\n        self.path = path\n\n    def __repr__(self):\n        return f\"!Ref {self.path}\"\n\n\ndef ref_constructor(loader, node):\n    \"\"\"\n    YAML constructor that parses the !Ref tag into a Ref object.\n    Args:\n        loader (yaml.Loader): The YAML loader instance.\n        node (yaml.Node): The YAML node representing the !Ref tag.\n    Returns:\n        Ref: A Ref object containing the referenced path.\n    \"\"\"\n    return Ref(loader.construct_scalar(node))\n\n\n# Register custom YAML tag handlers\n# but JSON does not support custom tags like !Ref directly.\nyaml.SafeLoader.add_constructor(\"!Ref\", ref_constructor)\n\n\ndef resolve_refs(config_item: Any, global_config: Dict[str, Any]) -> Any:\n    \"\"\"\n    Parse the reference tag (!Ref) in the configuration item and replace it with the corresponding actual value\n        in the global configuration.\n\n    Args:\n        config_item (Any): The configuration item to be parsed, which can be a\n            dictionary, list, string, or other types.\n        global_config (Dict[str, Any]): The global configuration dictionary\n            containing all referenceable configuration items.\n\n    Returns:\n        Any: The parsed configuration item, where the !Ref tags have been replaced with actual values.\n\n    Raises:\n        ValueError: An exception is thrown when the referenced path does not exist in the global configuration.\n    \"\"\"\n\n    if isinstance(config_item, dict):\n        return {k: resolve_refs(v, global_config) for k, v in config_item.items()}\n    elif isinstance(config_item, list):\n        return [resolve_refs(item, global_config) for item in config_item]\n    elif isinstance(config_item, Ref):\n        ref_path = config_item.path\n        parts = ref_path.split(\".\")\n        if parts[0] != \"global_config\":\n            raise ValueError(f\"Unsupported yaml Ref '{parts[0]}', only 'global_config' is supported\")\n        parts = parts[1:]\n        current = global_config\n        for part in parts:\n            current = current.get(part)\n            if current is None:\n                raise ValueError(f\"Unresolved reference '{ref_path}'.\")\n        return current\n    else:\n        return config_item\n\n\nclass DAGConfigLoader:\n    \"\"\"\n    Loads, parses, and constructs TaskGraph objects from YAML or JSON files.\n    \"\"\"\n\n    def __init__(self):\n        pass\n\n    @staticmethod\n    def _parse_raw_config(raw_dag_config: Dict[str, Any], file_path: str) -> TaskGraph:\n        \"\"\"\n   
     Helper function to parse and build a TaskGraph from a raw dictionary configuration.\n        This function is called by load_dag_from_file to handle common logic after loading YAML/JSON.\n\n        Args:\n            raw_dag_config (Dict[str, Any]): The raw configuration dictionary loaded from the file.\n            file_path (str): The path of the configuration file.\n\n        Returns:\n            TaskGraph: A TaskGraph object constructed from the configuration.\n        \"\"\"\n        if not raw_dag_config:\n            raise ValueError(f\"The configuration file '{file_path}' is empty or has an incorrect format.\")\n\n        dag_id = raw_dag_config.get(\"dag_id\")\n        if not dag_id:\n            raise ValueError(f\"The 'dag_id' is missing in the configuration '{file_path}'.\")\n\n        description = raw_dag_config.get(\"description\")\n        global_config = raw_dag_config.get(\"global_config\", {})\n\n        # In YAML/JSON, nodes are defined as a list\n        if \"nodes\" not in raw_dag_config:\n            raise ValueError(f\"The 'nodes' list is missing in the DAG configuration\")\n        nodes_list_config = raw_dag_config.get(\"nodes\")\n        if not isinstance(nodes_list_config, list):\n            raise ValueError(f\"The 'nodes' field in the configuration '{file_path}' must be a list.\")\n\n        dag_nodes: List[Node] = []\n        node_ids = set()  # Used to store the node IDs that have appeared to verify uniqueness\n\n        for i, node_config_dict in enumerate(nodes_list_config):\n            if not isinstance(node_config_dict, dict):\n                logger.warning(f\"The configuration of the {i + 1}th node in the file '{file_path}' is not a dictionary and has been skipped.\")\n                continue\n\n            node_id = node_config_dict.get(\"node_id\")\n            if not node_id:\n                raise ValueError(f\"The 'node_id' is missing in the configuration of the {i + 1}th node in the file '{file_path}'.\")\n\n            # Verify the uniqueness of the node ID\n            if node_id in node_ids:\n                raise ValueError(f\"Duplicate node ID '{node_id}' found in the configuration file '{file_path}'.\")\n            node_ids.add(node_id)\n\n            if \"node_type\" not in node_config_dict:\n                raise ValueError(f\"Node '{node_id}' is missing 'node_type'\")\n            node_type_str = node_config_dict.get(\"node_type\").upper()\n            try:\n                node_type = NodeType[node_type_str]\n            except KeyError:\n                raise ValueError(f\"The 'node_type' ('{node_type_str}') of node '{node_id}' in the file '{file_path}' is invalid.\")\n\n            node_role_str = node_config_dict.get(\"node_role\", \"DEFAULT\").upper()\n            try:\n                node_role = NodeRole[node_role_str]\n            except KeyError:\n                raise ValueError(f\"The 'node_role' ('{node_role_str}') of node '{node_id}' in the file '{file_path}' is invalid.\")\n\n            # Whether this node only performs forward computation; defaults to False if not specified\n            only_forward_compute = node_config_dict.get(\"only_forward_compute\", False)\n            # The agent group to which this node belongs; defaults to 0 if not specified\n            agent_group = node_config_dict.get(\"agent_group\", 0)\n\n            dependencies = node_config_dict.get(\"dependencies\", [])\n            if not isinstance(dependencies, list):\n                raise ValueError(f\"The 'dependencies' of node 
'{node_id}' in the file '{file_path}' must be a list.\")\n\n            # Renamed 'config' from node_config_dict to avoid conflict with outer scope 'config' variable\n            node_specific_config = resolve_refs(node_config_dict.get(\"config\", {}), global_config)\n            executable_ref_str = node_config_dict.get(\"executable_ref\")\n\n            # Add node_id to its own config for easy access within the executable function (e.g., for logging)\n            node_specific_config[\"_node_id_\"] = node_id\n\n            # Multi-agent need extra params\n            agent_options = node_config_dict.get(\"agent_options\", None)\n            \n            filter_plugin = node_config_dict.get(\"filter_plugin\", None)\n            dag_node = Node(\n                node_id=node_id,\n                node_type=node_type,\n                node_role=node_role,\n                only_forward_compute=only_forward_compute,\n                agent_group=agent_group,\n                dependencies=dependencies,\n                config=node_specific_config,  # Use the renamed variable here\n                executable_ref=executable_ref_str,\n                filter_plugin=filter_plugin,\n                agent_options=agent_options\n            )\n            dag_nodes.append(dag_node)\n        task_graph = TaskGraph(dag_id)\n        task_graph.add_nodes(dag_nodes)\n        # Build adjacency lists and validate the graph\n        task_graph.build_adjacency_lists()\n        valid, msg = task_graph.validate_graph()\n        if not valid:\n            raise ValueError(f\"The graph loaded from the configuration is invalid: {msg}\")\n\n        logger.info(f\"TaskGraph '{dag_id}' built successfully with {len(task_graph.nodes)} nodes\")\n        return task_graph\n\n    @staticmethod\n    def load_from_file(file_path: str, file_type: str = \"yaml\") -> TaskGraph:\n        \"\"\"\n        Loads and constructs a TaskGraph from the specified YAML or JSON file.\n        Determines the file type based on file_type.\n\n        Args:\n            file_path (str): The path of the configuration file.\n            file_type (str): The type of the configuration file, default is yaml\n\n        Returns:\n            TaskGraph: A TaskGraph object constructed from the configuration file.\n        \"\"\"\n        raw_dag_config: Optional[Dict[str, Any]] = None  # Initialize to None\n\n        try:\n            with open(file_path, \"r\", encoding=\"utf-8\") as f:\n                if file_type in [\"yaml\", \"yml\"]:\n                    raw_dag_config = yaml.safe_load(f)\n                elif file_type == \"json\":\n                    raw_dag_config = json.load(f)\n                else:\n                    raise ValueError(f\"Unsupported file type: '{file_type}'. 
Please use yaml, yml, or json.\")\n        except FileNotFoundError:\n            logger.error(f\"The configuration file '{file_path}' was not found.\")\n            raise\n        except yaml.YAMLError as e:  # Specific exception for YAML parsing errors\n            logger.error(f\"Failed to parse the YAML file '{file_path}': {e}\")\n            raise\n        except json.JSONDecodeError as e:  # Specific exception for JSON parsing errors\n            logger.error(f\"Failed to parse the JSON file '{file_path}': {e}\")\n            raise\n        except Exception as e:  # Catch-all for other potential loading errors\n            logger.error(f\"An unknown error occurred while loading the configuration file '{file_path}': {e}\")\n            raise\n\n        # Check if loading resulted in None (e.g., empty file might be parsed as None by yaml/json libs)\n        if raw_dag_config is None:\n            raise ValueError(f\"The result of loading the configuration file '{file_path}' is empty. It might be an empty file or have a format issue.\")\n\n        # Delegate the rest of the parsing to the helper function\n        return DAGConfigLoader._parse_raw_config(raw_dag_config, file_path)\n"
  },
  {
    "path": "siirl/execution/dag/node.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport importlib\nfrom enum import Enum\nfrom typing import Any, Callable, Dict, List, Optional, Set, Tuple\n\nimport dacite\nfrom loguru import logger\n\nfrom siirl.params import log_dict_formatted\nfrom siirl.params.model_args import AgentArguments\nfrom siirl.models.loader import load_tokenizer\nimport dacite\n\n\ndef dynamic_load_function(func_path: str):\n    if \".\" not in func_path:\n        raise ValueError(f\"{func_path} is not a correct module path\")\n    \n    module_path, func_name = func_path.rsplit(\".\", 1)  \n    \n\n    module = importlib.import_module(module_path)\n    \n\n    target_func = getattr(module, func_name)\n\n    if not callable(target_func):\n        raise TypeError(f\"{func_path} is not an valid path\")\n    \n    return target_func\nclass NodeType(Enum):\n    \"\"\"\n    Define the types of nodes in the DAG.\n    \"\"\"\n\n    COMPUTE = \"COMPUTE\"  # General computing task\n    DATA_LOAD = \"DATA_LOAD\"  # Load data from DataLoader\n    ENV_INTERACT = \"ENV_INTERACT\"  # Interact with the environment\n    MODEL_INFERENCE = \"MODEL_INFERENCE\"  # Model inference\n    MODEL_TRAIN = \"MODEL_TRAIN\"  # Model training\n    PUT_TO_BUFFER = \"PUT_TO_BUFFER\"  # Put data into the distributed buffer\n    GET_FROM_BUFFER = \"GET_FROM_BUFFER\"  # Get data from the distributed buffer\n    BARRIER_SYNC = \"BARRIER_SYNC\"  # Global synchronization point\n    CUSTOM = \"CUSTOM\"  # User-defined node type, executed using an executable\n\n\nclass NodeRole(Enum):\n    \"\"\"\n    Define the roles that a node plays in a distributed reinforcement learning framework.\n    This helps with specific scheduling or resource allocation.\n    \"\"\"\n\n    DEFAULT = \"DEFAULT\"  # Default\n    ACTOR = \"ACTOR\"  # Actor\n    ADVANTAGE = \"ADVANTAGE\"  # ADVANTAGE\n    CRITIC = \"CRITIC\"  # Critic\n    ROLLOUT = \"ROLLOUT\"  # Rollout\n    REFERENCE = \"REFERENCE\"  # Reference\n    REWARD = \"REWARD\"  # Reward\n\n    DYNAMIC_SAMPLING = \"DYNAMIC_SAMPLING\"  # Dynamic sampling in databuffer\n\n\nclass NodeStatus(Enum):\n    \"\"\"\n    Define the execution status of a DAG node.\n    \"\"\"\n\n    PENDING = \"PENDING\"  # Waiting for dependencies to complete\n    READY = \"READY\"  # Dependencies completed, ready to execute\n    RUNNING = \"RUNNING\"  # Currently executing\n    COMPLETED = \"COMPLETED\"  # Execution completed successfully\n    FAILED = \"FAILED\"  # Execution failed\n    SKIPPED = \"SKIPPED\"  # Skipped\n\n\nclass AgentProcess:\n    def __init__(self, agent_options: AgentArguments, node_config):\n        from siirl.dag_worker.constants import DAGConstants\n        self.env = None\n        self.post_process = None\n        self.pre_process = None\n        intern_config = node_config.get(DAGConstants.INTERN_CONFIG)\n        if intern_config is None:\n            return\n        # init tokenizer for each node\n        if 
hasattr(intern_config.model, 'trust_remote_code'):\n            intern_config.model.trust_remote_code = True\n        tokenizer_module = load_tokenizer(model_args=intern_config.model)\n        self.tokenizer = tokenizer_module.get(\"tokenizer\")\n        if agent_options is None:\n            return\n        process_path: str = agent_options.process_path\n        self.pre_process_kwargs: dict = agent_options.pre_process_kwargs\n        self.post_process_kwargs: dict = agent_options.post_process_kwargs\n        self._init_process_handle(process_path)\n\n        self.env_path = agent_options.env_path\n        self.env_managers = [{}]  # map str to env instance\n\n        if self.env_path:\n            self.init_env_class()\n\n        self.env_handles = None\n\n    def load_attr(self, file_path, attr_name):\n        try:\n            module_name = f\"{hash(file_path) & 0xFFFFFFF}\"\n            spec = importlib.util.spec_from_file_location(module_name, file_path)\n            module = importlib.util.module_from_spec(spec)\n            spec.loader.exec_module(module)\n        except Exception as e:\n            raise RuntimeError(f\"Error loading class from '{file_path}': {e}\") from e\n        try:\n            attr = getattr(module, attr_name)\n            return attr\n        except Exception as e:\n            logger.warning(f\"Error loading attr from '{file_path}:{e}\")\n        return None\n\n    def init_env_class(self):\n        self.env = []\n        for env_path in self.env_path:\n            file_path, class_ref = env_path.split(\":\")\n            env = self.load_attr(file_path, class_ref)\n            self.env.append(env)\n\n    def _init_process_handle(self, process_path):\n        if process_path is not None:\n            self.pre_process = self.load_attr(process_path, \"pre_process\")\n            self.post_process = self.load_attr(process_path, \"post_process\")\n\n    # each agent may have different tokenizer\n    # so, we make sure preprocess get str instead of list except get List[int] from dataloader in first agent\n    def apply_pre_process(self, prompt: Optional[Tuple[str, List]], obs: Optional[Tuple[str, List]]) -> str:\n        \"\"\"\n        Applies preprocessing to the input prompt (and optional environment observation) to generate a templated prompt.\n\n        Converts raw prompts to token IDs (if needed) and uses a custom preprocessing function (if configured)\n        to format the prompt (e.g., adding chat templates, incorporating observations).\n\n        Args:\n            prompt: Input prompt to preprocess. 
Can be either a raw string (to be tokenized) or a list of token IDs.\n            obs: Optional environment observation (tuple of string/list) to incorporate into the prompt\n            (for agent-environment interactions).\n\n        Returns:\n            Tuple[List[int], List[int]]:\n                - Original prompt (converted to token IDs if it was a string).\n                - Templated prompt (token IDs after preprocessing, e.g., with chat templates or observations added).\n        \"\"\"\n        templated_prompt = None\n        if isinstance(prompt, str):\n            prompt = self.tokenizer.encode(prompt)\n        if self.pre_process:\n            templated_prompt = self.pre_process(self.tokenizer, prompt, obs, **self.pre_process_kwargs)\n        else:\n            templated_prompt = prompt\n        return prompt, templated_prompt\n\n    # each agent may have different tokenizer\n    # so, we make sure postprocess get list[int] and return str\n    def apply_post_process(self, oridinal_prompt, templated_prompt, response) -> Tuple[List[int], List[int]]:\n        \"\"\"\n        Applies postprocessing to the generation result to combine the original prompt with the response,\n        and generates a mask for the response tokens.\n\n        Converts raw string responses to token IDs (if needed), merges the prompt with the response,\n        and creates a binary mask to identify response tokens (for training tasks like next-token prediction).\n\n        Note: Each agent may use a different tokenizer, so this method ensures input is list of token IDs\n        and returns properly formatted outputs (decoded string for original prompt,\n        token IDs for templated prompt/mask).\n\n        Args:\n            oridinal_prompt: Original prompt (list of token IDs) before generation.\n            templated_prompt: Preprocessed templated prompt (list of token IDs) used for generation.\n            response: Generated response to postprocess. 
Can be either a raw string (to be tokenized)\n            or a list of token IDs.\n\n        Returns:\n            Tuple[str, List[int], List[int]]:\n                - Decoded original prompt (string, merged with response tokens).\n                - Templated prompt merged with response tokens (list of token IDs, for model input).\n                - Response mask (binary list: 1 for response tokens, 0 otherwise; same length as response).\n        \"\"\"\n        if isinstance(response, str):\n            response = self.tokenizer.encode(response)\n        if self.post_process:\n            oridinal_prompt = self.post_process(self.tokenizer, oridinal_prompt, response, **self.post_process_kwargs)\n        else:\n            oridinal_prompt = oridinal_prompt + response\n        response_mask = [1] * len(response)\n        templated_prompt = templated_prompt + response\n        return self.tokenizer.decode(oridinal_prompt), templated_prompt, response_mask\n\n\nclass Node:\n    \"\"\"\n    Represents a node (task unit) in the DAG.\n    \"\"\"\n\n    def __init__(\n        self,\n        node_id: str,\n        node_type: NodeType,\n        node_role: NodeRole = NodeRole.DEFAULT,\n        only_forward_compute: bool = False,\n        agent_group: int = 0,\n        dependencies: Optional[List[str]] = None,\n        config: Optional[Dict[str, Any]] = None,\n        executable_ref: Optional[str] = None,\n        filter_plugin: Optional[Callable] = None,\n        agent_options: AgentArguments = None,\n        retry_limit: int = 0,\n    ):\n        \"\"\"\n        Initialize a node.\n\n        Args:\n            node_id (str): The unique identifier of the node.\n            node_type (NodeType): The type of the node.\n            node_role (NodeRole): The role played by the node. Defaults to NodeRole.DEFAULT.\n            dependencies (Optional[List[str]]): A list of IDs of other nodes that this node depends on.\n            Defaults to an empty list.\n            config (Optional[Dict[str, Any]]): Specific configuration information for the node.\n            Defaults to an empty dictionary.\n            executable_ref (Optional[str]): A string reference to the Python function for the node's execution logic\n                                           (e.g., \"my_module.my_submodule.my_function\").\n                                           If None, it means the node may have built-in logic or be handled by\n                                           an external executor.\n            retry_limit (int): The maximum number of retries when the node execution fails. 
Defaults to 0 (no retries).\n        \"\"\"\n        if not isinstance(node_id, str) or not node_id:\n            raise ValueError(\"node_id must be a non-empty string.\")\n        if not isinstance(node_type, NodeType):\n            raise ValueError(\"node_type must be a member of the NodeType enum.\")\n        if not isinstance(node_role, NodeRole):\n            raise ValueError(\"node_role must be a member of the NodeRole enum.\")\n        if (\n            node_type not in [NodeType.COMPUTE, NodeType.MODEL_TRAIN, NodeType.MODEL_INFERENCE]\n            and node_role != NodeRole.DEFAULT\n        ):\n            raise ValueError(\"The role type of non-model nodes must be DEFAULT\")\n\n        self.node_id: str = node_id\n        self.node_type: NodeType = node_type\n        self.node_role: NodeRole = node_role\n        self.only_forward_compute: bool = only_forward_compute\n        self.agent_group: int = agent_group\n        self.dependencies: List[str] = dependencies or []\n        self.config: Dict[str, Any] = config or {}\n        self.executable_ref: Optional[str] = executable_ref\n        self.retry_limit: int = retry_limit\n        self.retries_done: int = 0\n\n        self.async_rollout = None\n        self.mode = \"sync\"\n        self._executable: Optional[Callable] = None\n        self.output: Any = None  # Store the result of the node execution\n        self.error_info: Optional[str] = None  # Store error information when the node fails\n        if isinstance(agent_options, Dict):\n            agent_options: AgentArguments = dacite.from_dict(\n                data_class=AgentArguments,\n                data=agent_options,\n                config=dacite.Config(strict=False)\n            )\n        self.agent_options = agent_options\n        self.agent_process = AgentProcess(agent_options, self.config)\n        if self.executable_ref:\n            self._resolve_executable()\n\n        self.status: NodeStatus = NodeStatus.PENDING\n        if filter_plugin:\n            self.filter_plugin = dynamic_load_function(filter_plugin)\n    def _resolve_executable(self) -> None:\n        \"\"\"\n        Dynamically import and obtain the executable function based on the executable_ref string.\n\n        Supports two formats:\n        1. \"module.path:ClassName.method\" - imports module.path, then gets ClassName.method\n        2. 
\"module.path.function\" - imports module.path, then gets function\n        \"\"\"\n        if not self.executable_ref:\n            self._executable = None\n            return\n\n        try:\n            # Check if colon separator is present (format: module.path:ClassName.method)\n            if \":\" in self.executable_ref:\n                module_path, attr_path = self.executable_ref.split(\":\", 1)\n                module = importlib.import_module(module_path)\n                # Handle nested attributes (e.g., \"ClassName.method\")\n                obj = module\n                for attr_name in attr_path.split(\".\"):\n                    obj = getattr(obj, attr_name)\n                self._executable = obj\n            else:\n                # Fall back to original behavior (format: module.path.function)\n                module_path, function_name = self.executable_ref.rsplit(\".\", 1)\n                module = importlib.import_module(module_path)\n                self._executable = getattr(module, function_name)\n\n            if not callable(self._executable):\n                raise AttributeError(f\"The object resolved from '{self.executable_ref}' is not callable.\")\n        except (ImportError, AttributeError, ValueError) as e:\n            raise ImportError(f\"Failed to load the executable function from '{self.executable_ref}': {e}\") from e\n\n    @property\n    def executable(self) -> Optional[Callable]:\n        \"\"\"Return the resolved executable function.\"\"\"\n        return self._executable\n\n    @executable.setter\n    def executable(self, execute: Optional[Callable]):\n        \"\"\"Set the executable function for this node.\"\"\"\n        self._executable = execute\n\n    def add_dependency(self, dependency_id: str) -> None:\n        \"\"\"\n        Add a dependency.\n        Args:\n            dependency_id (str): The ID of the dependent node.\n        \"\"\"\n        if dependency_id not in self.dependencies:\n            self.dependencies.append(dependency_id)\n\n    def remove_dependency(self, dependency_id: str) -> None:\n        \"\"\"\n        Remove a dependency.\n        Args:\n            dependency_id (str): The ID of the dependency node to be removed.\n        \"\"\"\n        if dependency_id in self.dependencies:\n            self.dependencies.remove(dependency_id)\n\n    def is_ready(self, completed_node_ids: Set[str]) -> bool:\n        \"\"\"\n        Check if all dependencies of this node have been completed.\n        Args:\n            completed_node_ids (Set[str]): A set of IDs of completed nodes.\n        Returns:\n            bool: True if all dependencies are completed, otherwise False.\n        \"\"\"\n        if self.status != NodeStatus.PENDING:  # Only nodes in PENDING status can become READY\n            return False\n        return all(dep_id in completed_node_ids for dep_id in self.dependencies)\n\n    def update_status(self, new_status: NodeStatus, error_info: Optional[str] = None) -> None:\n        \"\"\"Update the node status and record error information (if applicable).\"\"\"\n        self.status = new_status\n        if error_info:\n            self.error_info = error_info\n        if new_status == NodeStatus.FAILED:\n            logger.error(f\"Node {self.node_id} execution failed: {error_info or 'Unknown error'}\")\n        elif new_status == NodeStatus.COMPLETED:\n            self.error_info = None  # Clear previous error information\n\n    def update_config(self, new_config_items: Dict[str, Any], overwrite: bool = True) -> 
None:\n        \"\"\"\n        Update the node's configuration.\n\n        Args:\n            new_config_items (Dict[str, Any]): A dictionary containing configuration keys and values to add or update.\n            overwrite (bool): If True (default), existing keys in the node's config will be overwritten\n                              by those in new_config_items. If False, existing keys will be preserved,\n                              and only new keys from new_config_items will be added.\n        \"\"\"\n        if not isinstance(new_config_items, dict):\n            logger.warning(\n                f\"Node {self.node_id}: Failed to update config. Provided new_config_items is not a dictionary (\"\n                f\"type: {type(new_config_items)}).\"\n            )\n            return\n\n        if overwrite:\n            self.config.update(new_config_items)\n            logger.info(f\"Node {self.node_id}: Configuration updated (overwrite=True).\")\n        else:\n            for key, value in new_config_items.items():\n                if key not in self.config:\n                    self.config[key] = value\n            logger.info(f\"Node {self.node_id}: Configuration updated (overwrite=False, existing keys preserved).\")\n\n        log_dict_formatted(self.config, title=f\"Node {self.node_id} current config\", log_level=\"debug\")\n\n    def can_retry(self) -> bool:\n        \"\"\"Check if the node can be retried.\"\"\"\n        return self.status == NodeStatus.FAILED and self.retries_done < self.retry_limit\n\n    def increment_retry_count(self) -> None:\n        \"\"\"Increment the retry count.\"\"\"\n        self.retries_done += 1\n\n    def run(self, **kwargs: Any) -> Any:\n        \"\"\"\n        Execute the task of the node.\n        Args:\n            **kwargs: Parameters passed to the executable function, usually the outputs of its dependent nodes.\n        Returns:\n            Any: The result of the node execution.\n        \"\"\"\n        logger.debug(\n            f\"Starting to execute node: {self.node_id} (Type: {self.node_type.value}, Role: {self.node_role.value})\"\n        )\n        self.update_status(NodeStatus.RUNNING)\n\n        if not self.executable:\n            # For nodes without an executable reference, they may be handled by an external system,\n            # or they are purely structural nodes (e.g., BARRIER_SYNC, whose logic is in the scheduler).\n            # one implement for barrier...\n            if self.node_type == NodeType.BARRIER_SYNC and kwargs.get(\"do_barrier\", False):\n                import torch.distributed as dist\n\n                logger.debug(f\"Node {self.node_id} block before barrier ...\")\n                dist.barrier(group=kwargs.get(\"barrier_group\", None))\n\n            logger.debug(f\"Node {self.node_id} has no executable function, skipping execution.\")\n            self.output = None  # Or set a specific output based on the node type\n            return self.output\n\n        try:\n            import inspect\n\n            # Check if the executable is an unbound method (needs self parameter)\n            # If it's a method defined in a class but not bound to an instance, bind it now\n            executable = self._executable\n\n            # Check if this is an unbound method that needs 'self'\n            # This happens when the method is loaded from \"module:Class.method\" format\n            if inspect.isfunction(executable) or inspect.ismethod(executable):\n                sig = inspect.signature(executable)\n           
     params = list(sig.parameters.keys())\n\n                # If the first parameter is 'self' and it's not bound yet, we need to bind it\n                if params and params[0] == 'self':\n                    # Get the DAGWorker instance from kwargs\n                    # The calling code should pass the DAGWorker instance\n                    dag_worker = kwargs.pop('_dag_worker_instance', None)\n                    if dag_worker is None:\n                        raise ValueError(\n                            f\"Node {self.node_id}: Executable '{self.executable_ref}' requires 'self' parameter, \"\n                            f\"but '_dag_worker_instance' was not provided in kwargs. \"\n                            f\"Please pass _dag_worker_instance=self when calling node.run().\"\n                        )\n                    # Bind the method to the instance\n                    import types\n                    executable = types.MethodType(executable, dag_worker)\n\n            # Simplification: Pass all kwargs directly, and the user function handles them\n            node_output = executable(**kwargs)\n            self.output = node_output\n            self.update_status(NodeStatus.COMPLETED)\n            logger.debug(f\"Node {self.node_id} execution completed.\")\n            return self.output\n        except Exception as e:\n            error_message = f\"An error occurred while executing node {self.node_id}: {e}\"\n            self.update_status(NodeStatus.FAILED, error_message)\n            # An exception can be raised here, or the scheduler can handle the FAILED status\n            raise RuntimeError(error_message) from e\n\n    def __repr__(self) -> str:\n        return (\n            f\"Node(node_id='{self.node_id}', type='{self.node_type.value}', role='{self.node_role.value}', \"\n            f\"agent_group='{self.agent_group}', only_forward_compute='{self.only_forward_compute}', \"\n            f\"status='{self.status.value}', deps={len(self.dependencies)})\"\n        )\n\n    def copy(self) -> \"Node\":\n        new_node = Node(\n            node_id=self.node_id,\n            node_type=self.node_type,\n            node_role=self.node_role,\n            dependencies=list(self.dependencies),\n            config=dict(self.config),\n            executable_ref=self.executable_ref,\n            retry_limit=self.retry_limit,\n            only_forward_compute=self.only_forward_compute,\n            agent_group=self.agent_group,\n            filter_plugin=getattr(self, \"filter_plugin\", None),\n            agent_options=self.agent_options,\n        )\n        new_node.status = self.status\n        new_node.retries_done = self.retries_done\n        new_node._executable = self._executable\n        return new_node\n"
  },
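A minimal sketch of how a `Node` from the module above can be wired to a plain callable and executed; the `compute_reward` function, its inputs, and the `"rollout"` dependency are hypothetical placeholders, and real pipelines would normally resolve the callable from `executable_ref` instead of assigning it directly.

```python
from siirl.execution.dag.node import Node, NodeRole, NodeStatus, NodeType


def compute_reward(rollout_output=None, **kwargs):
    # Hypothetical reward: 1.0 when an upstream rollout produced something.
    return {"reward": 1.0 if rollout_output else 0.0}


node = Node(
    node_id="reward",
    node_type=NodeType.COMPUTE,
    node_role=NodeRole.REWARD,
    dependencies=["rollout"],
)
node.executable = compute_reward  # bypass executable_ref resolution with a direct callable

# A scheduler would normally drive this loop; here the dependency is marked done by hand.
if node.is_ready(completed_node_ids={"rollout"}):
    node.update_status(NodeStatus.READY)
    result = node.run(rollout_output={"responses": ["..."]})
    print(node.status, result)  # NodeStatus.COMPLETED {'reward': 1.0}
```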
  {
    "path": "siirl/execution/dag/pipeline.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nSimplified Pipeline Builder API for defining DAG workflows in Python code.\n\nThis module provides a clean, intuitive API for users to define their training\npipelines directly in Python, with explicit function bindings for each node.\n\"\"\"\n\nfrom typing import Callable, Optional, List, Dict, Any, Union\nfrom dataclasses import dataclass, field\nfrom loguru import logger\n\nfrom siirl.execution.dag.node import Node, NodeType\nfrom siirl.execution.dag.task_graph import TaskGraph\n\n\n@dataclass\nclass NodeConfig:\n    \"\"\"Configuration for a pipeline node.\"\"\"\n    agent_group: int = 0\n    config: Dict[str, Any] = field(default_factory=dict)\n\n\nclass Pipeline:\n    \"\"\"\n    Simplified Pipeline builder for defining training workflows in Python.\n\n    Users can directly specify the function to execute for each node,\n    making the entire workflow transparent and easy to understand.\n\n    Example:\n        >>> pipeline = Pipeline(\"grpo_training\")\n        >>> pipeline.add_node(\n        ...     \"rollout\",\n        ...     func=\"siirl.dag_worker.dagworker:DAGWorker.generate\",\n        ...     deps=[]\n        ... )\n        >>> pipeline.add_node(\n        ...     \"reward\",\n        ...     func=\"siirl.dag_worker.dagworker:DAGWorker.compute_reward\",\n        ...     deps=[\"rollout\"]\n        ... )\n        >>> graph = pipeline.build()\n    \"\"\"\n\n    def __init__(self, pipeline_id: str, description: str = \"\"):\n        \"\"\"\n        Initialize a Pipeline builder.\n\n        Args:\n            pipeline_id: Unique identifier for this pipeline\n            description: Human-readable description of the pipeline\n        \"\"\"\n        self.pipeline_id = pipeline_id\n        self.description = description\n        self._nodes: Dict[str, Dict[str, Any]] = {}\n\n    def add_node(\n        self,\n        node_id: str,\n        func: Union[str, Callable],\n        deps: Optional[List[str]] = None,\n        config: Optional[NodeConfig] = None,\n        **kwargs\n    ) -> \"Pipeline\":\n        \"\"\"\n        Add a node to the pipeline.\n\n        Args:\n            node_id: Unique identifier for this node\n            func: Function to execute. 
Can be:\n                  - String path: \"module.path:ClassName.method\" or \"module.path:function\"\n                  - Callable: Direct function reference\n            deps: List of node IDs that this node depends on\n            config: Node configuration (optional)\n            **kwargs: Additional node parameters (e.g., only_forward_compute)\n\n        Returns:\n            self: For method chaining\n\n        Raises:\n            ValueError: If node_id already exists in the pipeline\n        \"\"\"\n        if node_id in self._nodes:\n            raise ValueError(f\"Node '{node_id}' already exists in pipeline '{self.pipeline_id}'\")\n\n        deps = deps or []\n        config = config or NodeConfig()\n\n        self._nodes[node_id] = {\n            \"func\": func,\n            \"deps\": deps,\n            \"config\": config,\n            \"kwargs\": kwargs\n        }\n\n        logger.debug(f\"Added node '{node_id}' to pipeline '{self.pipeline_id}'\")\n        return self\n\n    def build(self) -> TaskGraph:\n        \"\"\"\n        Build and validate the TaskGraph from the pipeline definition.\n\n        Returns:\n            TaskGraph: A validated TaskGraph ready for execution\n\n        Raises:\n            ValueError: If the pipeline is invalid (e.g., circular dependencies)\n        \"\"\"\n        from siirl.execution.dag.node import NodeRole\n\n        task_graph = TaskGraph(graph_id=self.pipeline_id)\n\n        for node_id, node_info in self._nodes.items():\n            # Extract node_type and node_role from kwargs if provided, otherwise use defaults\n            kwargs = node_info[\"kwargs\"].copy()\n            node_type = kwargs.pop(\"node_type\", NodeType.COMPUTE)\n            node_role = kwargs.pop(\"node_role\", NodeRole.DEFAULT)\n\n            # Create Node instance\n            node = Node(\n                node_id=node_id,\n                node_type=node_type,\n                node_role=node_role,\n                dependencies=node_info[\"deps\"],\n                agent_group=node_info[\"config\"].agent_group,\n                config=node_info[\"config\"].config,\n                **kwargs\n            )\n\n            # Bind the executable function\n            func = node_info[\"func\"]\n            if isinstance(func, str):\n                # Function specified as string path\n                node.executable_ref = func\n                node._resolve_executable()\n            else:\n                # Direct callable\n                node.executable = func\n\n            task_graph.add_node(node)\n\n        # Build adjacency lists and validate\n        task_graph.build_adjacency_lists()\n        valid, msg = task_graph.validate_graph()\n        if not valid:\n            raise ValueError(f\"Invalid pipeline '{self.pipeline_id}': {msg}\")\n\n        logger.info(f\"Pipeline '{self.pipeline_id}' built successfully with {len(self._nodes)} nodes\")\n        return task_graph\n\n    def visualize(self, output_path: str = None, directory: str = \"./\"):\n        \"\"\"\n        Visualize the pipeline structure.\n\n        Args:\n            output_path: Filename for the visualization (without extension)\n            directory: Directory to save the visualization\n\n        Returns:\n            TaskGraph: The built task graph\n        \"\"\"\n        graph = self.build()\n        if output_path:\n            graph.save_dag_pic(filename=output_path, directory=directory)\n            logger.info(f\"Pipeline visualization saved to {directory}/{output_path}\")\n        
return graph\n"
  },
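A minimal sketch, assuming the `Pipeline` API above, of building a two-node graph from plain callables rather than string executable paths; `generate` and `score` are hypothetical stand-ins for real rollout and reward functions.

```python
from siirl.execution.dag.node import NodeRole, NodeType
from siirl.execution.dag.pipeline import NodeConfig, Pipeline


def generate(**kwargs):
    # Hypothetical rollout stand-in.
    return {"responses": ["hello"]}


def score(**kwargs):
    # Hypothetical reward stand-in.
    return {"reward": 1.0}


pipeline = (
    Pipeline("toy_pipeline", description="rollout followed by reward")
    .add_node("rollout", func=generate, node_type=NodeType.MODEL_INFERENCE, node_role=NodeRole.ROLLOUT)
    .add_node("reward", func=score, deps=["rollout"], config=NodeConfig(agent_group=0))
)

graph = pipeline.build()             # validates dependencies and checks for cycles
print(graph.get_topological_sort())  # ['rollout', 'reward']
```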
  {
    "path": "siirl/execution/dag/task_graph.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nfrom typing import Any, Dict, List, Optional, Tuple\n\nfrom loguru import logger\n\nfrom siirl.execution.dag.node import Node, NodeRole, NodeStatus, NodeType\n\n\nclass TaskGraph:\n    \"\"\"\n    Represents a Directed Acyclic Graph (DAG) of tasks, composed of multiple Node objects and their dependencies.\n    \"\"\"\n\n    def __init__(self, graph_id: str):\n        \"\"\"\n        Initialize a task graph.\n        Parameters:\n            graph_id (str): The unique identifier of the graph.\n        \"\"\"\n        self.graph_id: str = graph_id\n        self.nodes: Dict[str, Node] = {}  # node_id -> Node object\n        # Forward adjacency list: node_id -> list of node_ids that depend on it (dependents)\n        self.adj: Dict[str, List[str]] = {}\n        # Reverse adjacency list (more commonly used for dependency checking):\n        #       node_id -> list of node_ids it depends on (dependencies)\n        self.rev_adj: Dict[str, List[str]] = {}\n\n    def add_node(self, node: Node) -> None:\n        \"\"\"\n        Add a node to the graph.\n        If the node already exists, it will be updated (Note: This may cause inconsistent dependencies.\n            It is recommended to delete and then add, or use with caution).\n        Parameters:\n            node (Node): The Node object to add.\n        \"\"\"\n        if not isinstance(node, Node):\n            raise ValueError(\"Only Node type objects can be added to the graph.\")\n        if node.node_id in self.nodes:\n            logger.warning(f\"Warning: Node {node.node_id} already exists in the graph. 
It will be replaced.\")\n\n        self.nodes[node.node_id] = node\n        # Update the adjacency list\n        self._update_adj_for_node(node)\n\n    def add_nodes(self, nodes: List[Node]) -> None:\n        \"\"\"\n        Add multiple nodes to the graph in batch.\n        If a node already exists, it will be updated (Note: This may cause inconsistent dependencies.\n            It is recommended to delete and then add, or use with caution).\n        Parameters:\n            nodes (List[Node]): A list of Node objects to add.\n        \"\"\"\n        for node in nodes:\n            self.add_node(node)\n\n    def _update_adj_for_node(self, node: Node) -> None:\n        \"\"\"Update the adjacency list and reverse adjacency list for a single node.\"\"\"\n        # Initialize the adjacency list for the current node\n        self.adj.setdefault(node.node_id, [])\n        self.rev_adj.setdefault(node.node_id, [])\n\n        # Update the reverse adjacency list (node -> its dependencies)\n        self.rev_adj[node.node_id] = list(node.dependencies)  # Ensure it's a copy of the list\n\n        # Update the forward adjacency list (dependency -> node)\n        for dep_id in node.dependencies:\n            if dep_id not in self.nodes:\n                # Allow adding nodes whose dependencies do not yet exist,\n                # but it will be checked when validating the graph\n                pass\n            self.adj.setdefault(dep_id, []).append(node.node_id)\n            # Remove duplicates, although there should usually be no duplicate dependencies\n            self.adj[dep_id] = list(set(self.adj[dep_id]))\n\n    def build_adjacency_lists(self) -> None:\n        \"\"\"\n        Completely (re)build the adjacency list and reverse adjacency list based on\n            the dependencies of all nodes in the graph.\n        Call this method after all nodes have been added,\n            or after significant changes to the node dependencies.\n        \"\"\"\n        self.adj.clear()\n        self.rev_adj.clear()\n\n        for node_id, node in self.nodes.items():\n            self.adj.setdefault(node_id, [])\n            self.rev_adj.setdefault(node_id, list(node.dependencies))  # node -> its dependencies\n            for dep_id in node.dependencies:\n                if dep_id not in self.nodes:\n                    # Allow dependencies on nodes not yet defined in the graph. 
validate_graph will handle this.\n                    self.adj.setdefault(dep_id, [])  # Ensure dep_id is in adj\n                else:\n                    self.adj.setdefault(dep_id, []).append(node_id)  # dependency -> node\n\n        # Clean up duplicates in adj (if necessary)\n        for node_id in self.adj:\n            self.adj[node_id] = list(set(self.adj[node_id]))\n\n    def get_node(self, node_id: str) -> Optional[Node]:\n        \"\"\"\n        Get the node with the specified ID.\n        Parameters:\n            node_id (str): The node ID.\n        Returns:\n            Optional[Node]: The Node object if found, otherwise None.\n        \"\"\"\n        return self.nodes.get(node_id)\n\n    def get_dependencies(self, node_id: str) -> List[Node]:\n        \"\"\"\n        Get all direct dependent nodes of the specified node.\n        Parameters:\n            node_id (str): The node ID.\n        Returns:\n            List[Node]: A list of dependent Node objects.\n        \"\"\"\n        if node_id not in self.nodes:\n            return []\n        # return [self.nodes[dep_id] for dep_id in self.nodes[node_id].dependencies if dep_id in self.nodes]\n        # It's more straightforward to use rev_adj if it's already built\n        if not self.rev_adj or node_id not in self.rev_adj:  # Ensure rev_adj is built\n            self.build_adjacency_lists()\n\n        return [self.nodes[dep_id] for dep_id in self.rev_adj.get(node_id, []) if dep_id in self.nodes]\n\n    def get_dependents(self, node_id: str) -> List[Node]:\n        \"\"\"\n        Get all nodes that directly depend on the specified node.\n        Parameters:\n            node_id (str): The node ID.\n        Returns:\n            List[Node]: A list of Node objects that depend on this node.\n        \"\"\"\n        if node_id not in self.nodes:\n            return []\n        if not self.adj or node_id not in self.adj:  # Ensure adj is built\n            self.build_adjacency_lists()\n\n        return [self.nodes[dep_id] for dep_id in self.adj.get(node_id, []) if dep_id in self.nodes]\n\n    def get_downstream_nodes(self, node_id: str) -> List[Node]:\n        \"\"\"Get the direct downstream nodes of a node\"\"\"\n        return self.get_dependents(node_id)\n\n    def get_entry_nodes(self) -> List[Node]:\n        \"\"\"\n        Get all entry nodes in the graph that have no dependencies.\n        Returns:\n            List[Node]: A list of entry Node objects.\n        \"\"\"\n        if not self.rev_adj:  # Ensure rev_adj is built\n            self.build_adjacency_lists()\n        return [node for node_id, node in self.nodes.items() if not self.rev_adj.get(node_id)]\n\n    def get_exit_nodes(self) -> List[Node]:\n        \"\"\"\n        Get all exit nodes in the graph that have no subsequent dependent nodes.\n        Returns:\n            List[Node]: A list of exit Node objects.\n        \"\"\"\n        if not self.adj:  # Ensure adj is built\n            self.build_adjacency_lists()\n        return [node for node_id, node in self.nodes.items() if not self.adj.get(node_id)]\n\n    def validate_graph(self) -> Tuple[bool, Optional[str]]:\n        \"\"\"\n        Validate the validity of the graph.\n        1. Check if all node dependencies exist in the graph.\n        2. Check if there are circular dependencies (determine if it's a DAG).\n        Returns:\n            Tuple[bool, Optional[str]]: (Is valid, Error message or None)\n        \"\"\"\n        # 1. 
Check dependency existence\n        for node_id, node in self.nodes.items():\n            for dep_id in node.dependencies:\n                if dep_id not in self.nodes:\n                    return False, f\"The dependency '{dep_id}' of node '{node_id}' does not exist in the graph.\"\n\n        # 2. Check for circular dependencies (using topological sorting)\n        try:\n            self.get_topological_sort()\n            return True, None\n        except ValueError as e:  # get_topological_sort will raise a ValueError when a cycle is detected\n            return False, str(e)\n\n    def get_topological_sort(self) -> List[str]:\n        \"\"\"\n        Get the topological sorting of the nodes.\n        If the graph is not a DAG (i.e., there is a cycle), a ValueError will be raised.\n        Use Kahn's algorithm.\n        Returns:\n            List[str]: A list of node IDs in topological order.\n        \"\"\"\n        if not self.nodes:\n            return []\n\n        # Ensure the adjacency list is up-to-date\n        self.build_adjacency_lists()\n\n        in_degree = {node_id: 0 for node_id in self.nodes}\n        for node_id in self.nodes:\n            for dependent_id in self.adj.get(node_id, []):\n                in_degree[dependent_id] += 1\n\n        # In some cases, if a node is in the keys of adj but not in nodes\n        # (e.g., a dependency is declared but not added as a node itself),\n        # in_degree may contain keys that are not in self.nodes. Ensure only nodes in the graph are processed.\n        # However, it's better to handle this in build_adjacency_lists or add_node.\n        # Here, it's assumed that the keys of in_degree are all valid nodes in self.nodes.\n\n        queue = [node_id for node_id in self.nodes if in_degree[node_id] == 0]\n        topological_order = []\n\n        while queue:\n            u = queue.pop(0)\n            topological_order.append(u)\n\n            for v_id in self.adj.get(u, []):  # v_id is a subsequent node of u\n                if v_id in in_degree:  # Ensure v_id is a node in the graph\n                    in_degree[v_id] -= 1\n                    if in_degree[v_id] == 0:\n                        queue.append(v_id)\n                # else: # If v_id is not in in_degree, it may be a dependency not defined in nodes\n                #     pass # validate_graph should have already captured this situation\n\n        if len(topological_order) != len(self.nodes):\n            # Find the nodes in the cycle (optional, more complex)\n            # For simplicity, only report the existence of a cycle\n            raise ValueError(f\"There are circular dependencies in graph '{self.graph_id}', and topological sorting cannot be performed.\")\n\n        return topological_order\n\n    def reset_nodes_status(self) -> None:\n        \"\"\"Reset the status of all nodes in the graph to PENDING, and clear the output and error information.\"\"\"\n        for node in self.nodes.values():\n            node.status = NodeStatus.PENDING\n            node.output = None\n            node.error_info = None\n            node.retries_done = 0\n\n    @classmethod\n    def load_from_config(cls, graph_id: str, config_data: List[Dict[str, Any]]) -> \"TaskGraph\":\n        \"\"\"\n        Create a TaskGraph from configuration data (e.g., parsed from a YAML/JSON file).\n        The configuration data should be a list of dictionaries, each describing a node.\n\n        Parameters:\n            graph_id (str): The ID of the graph.\n            config_data 
(List[Dict[str, Any]]): A list of node configurations.\n                Each dictionary should contain: 'node_id', 'node_type' (in string form),\n                Optional: 'dependencies', 'config', 'executable_ref', 'node_role', 'retry_limit'.\n\n        Returns:\n            TaskGraph: The constructed task graph object.\n        \"\"\"\n        graph = cls(graph_id)\n        for node_conf in config_data:\n            try:\n                node_type_str = node_conf.get(\"node_type\")\n                if not node_type_str:\n                    raise ValueError(f\"Node configuration {node_conf.get('node_id', 'Unknown ID')} is missing 'node_type'.\")\n\n                node_type = NodeType[node_type_str.upper()]\n\n                node_role_str = node_conf.get(\"node_role\")\n                node_role = NodeRole[node_role_str.upper()] if node_role_str else NodeRole.DEFAULT\n                only_forward_compute = node_conf.get(\"only_forward_compute\", False)\n                agent_group = node_conf.get(\"agent_group\", 0)\n\n                node = Node(\n                    node_id=node_conf[\"node_id\"],\n                    node_type=node_type,\n                    node_role=node_role,\n                    only_forward_compute=only_forward_compute,\n                    agent_group=agent_group,\n                    dependencies=node_conf.get(\"dependencies\"),\n                    config=node_conf.get(\"config\"),\n                    executable_ref=node_conf.get(\"executable_ref\"),\n                    retry_limit=node_conf.get(\"retry_limit\", 0),\n                )\n                graph.add_node(node)\n            except KeyError as e:\n                raise ValueError(f\"Node configuration {node_conf.get('node_id', 'Unknown ID')} is missing required field: {e}\")\n            except ValueError as e:  # e.g., NodeType['INVALID']\n                raise ValueError(f\"Node configuration {node_conf.get('node_id', 'Unknown ID')} has a value error: {e}\")\n\n        graph.build_adjacency_lists()  # Ensure the adjacency list is built after all nodes are added\n        valid, msg = graph.validate_graph()\n        if not valid:\n            raise ValueError(f\"The graph loaded from the configuration is invalid: {msg}\")\n        return graph\n\n    def __repr__(self) -> str:\n        \"\"\"\n        Return the topological structure of the DAG in text symbol form.\n        Returns:\n            str: A string representation of the DAG's topological structure.\n        \"\"\"\n        output_lines = [f\"TaskGraph(graph_id='{self.graph_id}', num_nodes={len(self.nodes)})\"]\n        if not self.nodes:\n            return \"\\n\".join(output_lines)\n\n        try:\n            # get_topological_sort will call build_adjacency_lists internally\n            processing_order = self.get_topological_sort()\n        except ValueError as e:  # Capture circular dependency errors\n            return f\"Unable to display DAG graph '{self.graph_id}': {e}\"\n\n        output_lines.append(\"=\" * (len(output_lines[0])))\n\n        for node_id in processing_order:\n            node = self.nodes[node_id]\n            output_lines.append(f\"[{node.node_id}] ({node.node_type.value})\")\n\n            if node.executable_ref:\n                output_lines.append(f\"  Executable Ref: {node.executable_ref}\")\n            if node.config:\n                output_lines.append(f\"  Config: {node.config}\")\n            output_lines.append(f\"        {node}\")\n\n            # Display upstream dependencies (parent nodes)\n     
       parent_ids = sorted(self.rev_adj.get(node_id, []))\n            if parent_ids:\n                output_lines.append(f\"  ↑ (Depends on upstream)\")\n                for parent_id in parent_ids:\n                    output_lines.append(f\"    ↖── [{parent_id}]\")\n            elif not parent_ids:  # It's an entry node\n                output_lines.append(\"  (Entry node)\")\n\n            # Display downstream execution (child nodes)\n            children_ids = sorted(self.adj.get(node_id, []))\n            if children_ids:\n                # output_lines.append(f\"  ↓ (Subsequent execution)\") # Optional title line\n                for i, child_id in enumerate(children_ids):\n                    child_node = self.nodes.get(child_id)  # The child node should exist\n                    connector = \"  └─→ \" if i == len(children_ids) - 1 else \"  ├─→ \"\n                    output_lines.append(f\"{connector}[{child_id}] ({child_node.node_type.value if child_node else 'Unknown type'})\")\n            elif not children_ids:  # It's an exit node\n                output_lines.append(\"  (Exit node)\")\n\n            output_lines.append(\"\")  # Add a blank line after each node block for readability\n\n        return \"\\n\".join(output_lines).strip()\n\n    def copy(self) -> \"TaskGraph\":\n        new_graph = TaskGraph(graph_id=f\"{self.graph_id}_copy\")\n        for _, original_node in self.nodes.items():\n            new_graph.add_node(original_node.copy())\n        new_graph.build_adjacency_lists()\n        return new_graph\n\n    def save_dag_pic(self, filename: str = \"task_graph\", directory: Optional[str] = None, view: bool = False, cleanup: bool = True) -> Optional[str]:\n        \"\"\"\n        Visualize the DAG as an image using graphviz and save it.\n\n        Parameters:\n            filename (str): The file name of the output image (without the extension, e.g., \"dag_pic\").\n                            The final file name will be filename.format (e.g., dag_pic.png).\n            directory (Optional[str]): The directory to save the image.\n                If None, it will be saved in the current working directory.\n            view (bool): Whether to automatically open the image after generation.\n            cleanup (bool): Whether to delete the temporary DOT source file after rendering.\n\n        Returns:\n            Optional[str]: The full path of the image if successful, otherwise None.\n        \"\"\"\n        from graphviz import Digraph\n\n        if not self.nodes:\n            logger.warning(f\"DAG graph '{self.graph_id}' is empty. No image will be generated.\")\n            return None\n\n        # Ensure the adjacency list is up-to-date\n        self.build_adjacency_lists()\n\n        # Check the validity of the graph, e.g., if there are cycles\n        is_valid, error_msg = self.validate_graph()\n        if not is_valid:\n            logger.error(f\"Graph '{self.graph_id}' is invalid. 
Unable to generate image: {error_msg}\")\n            return None\n\n        dot = Digraph(name=self.graph_id, comment=f\"DAG for {self.graph_id}\")\n        dot.attr(rankdir=\"TB\")  # Layout from top to bottom (optional LR: from left to right)\n        dot.attr(label=f\"DAG: {self.graph_id}\", fontsize=\"20\")\n        dot.attr(labelloc=\"t\")  # Title position at the top\n\n        for node_id, node in self.nodes.items():\n            # Node label contains ID, type, and role\n            label = f\"{node.node_id}\\n({node.node_type.value})\"\n            if node.node_role:\n                label += f\"\\nRole: {node.node_role.value}\"\n\n            colors = {NodeType.MODEL_TRAIN: \"blue\", NodeType.MODEL_INFERENCE: \"green\", NodeType.DATA_LOAD: \"orange\", NodeType.BARRIER_SYNC: \"red\"}\n            color = colors.get(node.node_type, \"black\")\n            node_attrs = {\"penwidth\": \"2.0\", \"color\": color, \"fontcolor\": color}\n\n            dot.node(node_id, label=label, **node_attrs)\n\n        # Add edges\n        for node_id, children_ids in self.adj.items():\n            # Skip dependency nodes not defined in the graph (theoretically should not happen)\n            if node_id not in self.nodes:\n                continue\n            source_node = self.nodes[node_id]\n            for child_id in children_ids:\n                if child_id not in self.nodes:  # Skip child nodes not defined in the graph\n                    continue\n                # target_node = self.nodes[child_id]\n                # Default edge color\n                edge_color = \"black\"\n                # If necessary, change the edge color according to the source or target node type, e.g.:\n                # if source_node.node_type == NodeType.BARRIER_SYNC or target_node.node_type == NodeType.BARRIER_SYNC:\n                #     edge_color = \"red\" # If you want the edges connected to Barrier to be red too\n                dot.edge(node_id, child_id, color=edge_color)\n\n        try:\n            # Construct the full file path\n            output_path = os.path.join(directory or \".\", filename)\n            rendered_path = dot.render(filename=output_path, directory=None, view=view, cleanup=cleanup, format=\"svg\")\n            logger.info(f\"DAG image saved to: {rendered_path}\")\n            return rendered_path\n        except Exception as e:\n            logger.error(f\"An error occurred while generating the DAG image: {e}. 
Please ensure the Graphviz executable is in your system PATH.\")\n            return None\n\n    def get_nodes_by_type(self, node_types: List[NodeType]) -> List[Node]:\n        \"\"\"\n        Retrieves all nodes from the graph that match any of the specified node types.\n\n        Args:\n            node_types: A list of NodeType enums to filter by.\n\n        Returns:\n            A list of Node objects whose type is in the node_types list.\n        \"\"\"\n        return [node for node in self.nodes.values() if node.node_type in node_types]\n\n    def get_nodes_by_role(self, node_role: NodeRole) -> List[Node]:\n        \"\"\"\n        Retrieves all nodes from the graph that match the specified node role.\n\n        Args:\n            node_role: A NodeRole enum to filter by.\n\n        Returns:\n            A list of Node objects whose role matches the specified node_role.\n        \"\"\"\n        return [node for node in self.nodes.values() if node.node_role == node_role]\n\n\n# Example usage:\nif __name__ == \"__main__\":\n    logger.info(\"--- Demonstration of the core class of the DAG module ---\")\n\n    node_a = Node(node_id=\"rollout_actor\", node_type=NodeType.MODEL_INFERENCE, node_role=NodeRole.ROLLOUT)\n    node_b = Node(node_id=\"B\", node_type=NodeType.MODEL_INFERENCE, node_role=NodeRole.ROLLOUT, dependencies=[\"A\"])\n    node_c = Node(node_id=\"C\", node_type=NodeType.MODEL_INFERENCE, node_role=NodeRole.REFERENCE, dependencies=[\"A\"])\n    node_d = Node(node_id=\"D\", node_type=NodeType.BARRIER_SYNC, dependencies=[\"B\", \"C\"])\n    node_e = Node(node_id=\"E\", node_type=NodeType.MODEL_TRAIN, node_role=NodeRole.ACTOR, dependencies=[\"D\"])\n\n    graph = TaskGraph(graph_id=\"example_rl_pipeline\")\n    graph.add_nodes([node_a, node_b, node_c, node_d, node_e])\n    graph.build_adjacency_lists()  # Ensure the adjacency list is built\n\n    is_valid, validation_msg = graph.validate_graph()\n    if is_valid:\n        logger.info(f\"\\nGraph '{graph.graph_id}' passed validation.\")\n    else:\n        logger.info(f\"\\nGraph '{graph.graph_id}' failed validation: {validation_msg}\")\n        exit(1)\n\n    logger.info(f\"{graph}\")\n    graph.save_dag_pic()\n\n    logger.info(\"\\n--- print_dag of a graph with multiple independent branches ---\")\n    multi_branch_graph = TaskGraph(\"multi_branch_dag\")\n    mb_n1 = Node(\"MB1\", NodeType.DATA_LOAD)\n    mb_n2 = Node(\"MB2\", NodeType.COMPUTE, dependencies=[\"MB1\"])\n    mb_n3 = Node(\"MB3\", NodeType.DATA_LOAD)  # Another entry node\n    mb_n4 = Node(\"MB4\", NodeType.COMPUTE, dependencies=[\"MB3\"])\n    mb_n5 = Node(\"MB5\", NodeType.MODEL_TRAIN, dependencies=[\"MB2\", \"MB4\"])  # Merge node\n    multi_branch_graph.add_node(mb_n1)\n    multi_branch_graph.add_node(mb_n2)\n    multi_branch_graph.add_node(mb_n3)\n    multi_branch_graph.add_node(mb_n4)\n    multi_branch_graph.add_node(mb_n5)\n    logger.info(multi_branch_graph)\n    multi_branch_graph.save_dag_pic(\"multi_branch\")\n\n    logger.info(\"\\n--- print_dag of an empty graph ---\")\n    empty_graph = TaskGraph(\"empty_graph_for_print_dag\")\n    logger.info(empty_graph)\n\n    logger.info(\"\\n--- print_dag of a graph with circular dependencies ---\")\n    cyclic_graph = TaskGraph(\"cyclic_graph_for_print_dag\")\n    cg_n1 = Node(\"CG1\", NodeType.COMPUTE, dependencies=[\"CG2\"])\n    cg_n2 = Node(\"CG2\", NodeType.COMPUTE, dependencies=[\"CG1\"])\n    cyclic_graph.add_node(cg_n1)\n    cyclic_graph.add_node(cg_n2)\n    logger.info(cyclic_graph)\n"
  },
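A minimal sketch of constructing a `TaskGraph` from declarative config data via `load_from_config`, as defined above; the node IDs and the three-stage layout are illustrative only, not a prescribed workflow.

```python
from siirl.execution.dag.task_graph import TaskGraph

# Hypothetical three-stage pipeline expressed as node configs
# (the same shape a YAML-defined workflow would deserialize to).
config_data = [
    {"node_id": "load_batch", "node_type": "DATA_LOAD"},
    {"node_id": "rollout", "node_type": "MODEL_INFERENCE", "node_role": "ROLLOUT", "dependencies": ["load_batch"]},
    {"node_id": "train_actor", "node_type": "MODEL_TRAIN", "node_role": "ACTOR", "dependencies": ["rollout"]},
]

graph = TaskGraph.load_from_config("declarative_pipeline", config_data)
print([n.node_id for n in graph.get_entry_nodes()])  # ['load_batch']
print(graph.get_topological_sort())                  # ['load_batch', 'rollout', 'train_actor']
```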
  {
    "path": "siirl/execution/dag/task_loader.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport copy\nfrom typing import Dict, List, Optional, Tuple, Set\nimport itertools\n\nfrom loguru import logger\nfrom siirl.execution.dag import Node, NodeType, TaskGraph\n\n\ndef generate_structural_signature(graph: TaskGraph) -> str:\n    \"\"\"\n    Generates a canonical string signature for a TaskGraph based purely on its structure\n    (nodes, their types, roles, and connections), ignoring the graph_id.\n\n    Args:\n        graph (TaskGraph): The TaskGraph object for which to generate the signature.\n\n    Returns:\n        str: A unique string representing the structural signature of the graph.\n    \"\"\"\n    if not graph or not graph.nodes:\n        # Handle empty or invalid graphs\n        return f\"empty_structure_original_id_ref:{graph.graph_id}\"\n\n    # Ensure adjacency lists are built for accurate dependency representation\n    graph.build_adjacency_lists()\n\n    # Sort node IDs to ensure consistent signature generation regardless of insertion order\n    node_ids_sorted: List[str] = sorted(graph.nodes.keys())\n\n    node_details_parts: List[str] = []\n    # Collect details for each node, including its dependencies\n    for nid in node_ids_sorted:\n        node: Node = graph.nodes[nid]\n        # Include sorted dependencies to make the node's structural role explicit\n        sorted_deps: List[str] = sorted(list(node.dependencies))\n        node_details_parts.append(f\"n(id:{nid},t:{node.node_type.value},r:{node.node_role.value},d:[{','.join(sorted_deps)}])\")\n\n    # Explicitly list edges for robustness, based on the built adjacency list\n    edge_list_parts: List[str] = []\n    for parent_id in node_ids_sorted:\n        children_ids_sorted: List[str] = sorted(graph.adj.get(parent_id, []))\n        for child_id in children_ids_sorted:\n            edge_list_parts.append(f\"e({parent_id}->{child_id})\")\n\n    # Combine node and edge details into a single, canonical structural signature\n    return f\"struct_nodes:{';'.join(node_details_parts)}|struct_edges:{','.join(edge_list_parts)}\"\n\n\ndef get_all_downstream_nodes_recursive(src_task_graph: TaskGraph, start_node_id: str, visited: Optional[Set[str]] = None) -> Set[str]:\n    \"\"\"\n    Recursively finds all downstream nodes reachable from a given start node in a TaskGraph.\n\n    Args:\n        src_task_graph (TaskGraph): The source TaskGraph to traverse.\n        start_node_id (str): The ID of the starting node.\n        visited (Optional[Set[str]]): A set of visited nodes to prevent infinite loops in cycles (though DAGs shouldn't have them).\n\n    Returns:\n        Set[str]: A set of all downstream node IDs, including the start node itself.\n    \"\"\"\n    if visited is None:\n        visited = set()\n\n    # Base case: if node already visited or not in graph, return empty set\n    if start_node_id in visited or start_node_id not in src_task_graph.nodes:\n        
return set()\n\n    # Mark current node as visited and add to downstream set\n    visited.add(start_node_id)\n    downstream: Set[str] = {start_node_id}\n\n    # Recursively find downstream nodes for each child\n    for child_id in src_task_graph.adj.get(start_node_id, []):\n        downstream.update(get_all_downstream_nodes_recursive(src_task_graph, child_id, visited.copy()))\n    return downstream\n\n\ndef get_all_ancestors(graph: TaskGraph, node_id: str) -> Set[str]:\n    \"\"\"\n    Finds all ancestor nodes of a given node in a TaskGraph.\n\n    Args:\n        graph (TaskGraph): The TaskGraph to traverse.\n        node_id (str): The ID of the node for which to find ancestors.\n\n    Returns:\n        Set[str]: A set of all ancestor node IDs.\n    \"\"\"\n    if node_id not in graph.nodes:\n        return set()\n\n    # Ensure reverse adjacency list is built for efficient ancestor traversal\n    graph.build_adjacency_lists()\n\n    ancestors: Set[str] = set()\n    # Initialize queue with immediate parents\n    queue: List[str] = list(graph.rev_adj.get(node_id, []))\n    visited_ancestors: Set[str] = set(queue)\n    ancestors.update(queue)\n\n    head = 0\n    # BFS traversal to find all ancestors\n    while head < len(queue):\n        current_node_id: str = queue[head]\n        head += 1\n        for parent_id in graph.rev_adj.get(current_node_id, []):\n            if parent_id not in visited_ancestors:\n                visited_ancestors.add(parent_id)\n                ancestors.add(parent_id)\n                queue.append(parent_id)\n    return ancestors\n\n\ndef find_all_paths_dfs(src_task_graph: TaskGraph, current_node_id: str, end_node_id: str, current_path: List[str], all_paths: List[List[str]], visited_in_current_path: Set[str]):\n    \"\"\"\n    Helper function for find_all_paths, uses DFS to find all paths between two nodes.\n\n    Args:\n        src_task_graph (TaskGraph): The source TaskGraph.\n        current_node_id (str): The ID of the current node in the DFS traversal.\n        end_node_id (str): The ID of the target end node.\n        current_path (List[str]): The path built so far from the start node to current_node_id.\n        all_paths (List[List[str]]): A list to accumulate all found paths.\n        visited_in_current_path (Set[str]): Nodes visited in the current path to detect cycles (if graph were not a DAG).\n    \"\"\"\n    current_path.append(current_node_id)\n    visited_in_current_path.add(current_node_id)\n\n    if current_node_id == end_node_id:\n        # If end node reached, add a copy of the current path to all_paths\n        all_paths.append(list(current_path))\n    else:\n        # Explore neighbors\n        for neighbor_id in sorted(src_task_graph.adj.get(current_node_id, [])):\n            if neighbor_id not in visited_in_current_path:  # Prevent cycles\n                find_all_paths_dfs(src_task_graph, neighbor_id, end_node_id, current_path, all_paths, visited_in_current_path)\n    # Backtrack: remove current node from path and visited set\n    current_path.pop()\n    visited_in_current_path.remove(current_node_id)\n\n\ndef find_all_paths(src_task_graph: TaskGraph, start_node_id: str, end_node_id: str) -> List[List[str]]:\n    \"\"\"\n    Finds all simple paths from a start node to an end node in a TaskGraph.\n\n    Args:\n        src_task_graph (TaskGraph): The TaskGraph to search within.\n        start_node_id (str): The ID of the starting node.\n        end_node_id (str): The ID of the ending node.\n\n    Returns:\n        List[List[str]]: A list 
of lists, where each inner list represents a path\n                         as a sequence of node IDs.\n    \"\"\"\n    if start_node_id not in src_task_graph.nodes or end_node_id not in src_task_graph.nodes:\n        logger.warning(f\"Pathfinding: Start '{start_node_id}' or end '{end_node_id}' not in graph '{src_task_graph.graph_id}'.\")\n        return []\n\n    all_paths: List[List[str]] = []\n    src_task_graph.build_adjacency_lists()  # Ensure adj lists are ready for DFS\n    find_all_paths_dfs(src_task_graph, start_node_id, end_node_id, [], all_paths, set())\n    return all_paths\n\n\ndef split_single_structure(src_task_graph: TaskGraph, parallel_branch_node_lists: List[List[str]], merge_node_id: str, base_subgraph_idx_str: str) -> List[TaskGraph]:\n    \"\"\"\n    Splits a graph based on identified parallel branches converging at a merge node.\n    Each branch, along with its common upstream nodes and the common downstream nodes\n    starting from the merge node, forms a new subgraph.\n\n    Args:\n        src_task_graph (TaskGraph): The original TaskGraph to split.\n        parallel_branch_node_lists (List[List[str]]): A list of lists, where each inner list\n                                                      represents the sequence of nodes for a specific branch\n                                                      leading up to the merge node (exclusive of merge node).\n        merge_node_id (str): The ID of the node where the parallel branches re-converge.\n        base_subgraph_idx_str (str): A string prefix for naming the generated subgraphs.\n\n    Returns:\n        List[TaskGraph]: A list of new TaskGraph objects, each representing one of the split branches.\n    \"\"\"\n    if not src_task_graph.nodes:\n        return []\n    if merge_node_id not in src_task_graph.nodes:\n        return []\n\n    src_task_graph.build_adjacency_lists()\n\n    # Identify all nodes common to all branches from the merge node downwards\n    common_downstream_nodes_ids: Set[str] = get_all_downstream_nodes_recursive(src_task_graph, merge_node_id, visited=set())\n    if not common_downstream_nodes_ids:\n        common_downstream_nodes_ids = {merge_node_id}  # If merge node has no children, it's still part of the common suffix\n\n    created_subgraphs: List[TaskGraph] = []\n    # Create a subgraph for each identified parallel branch\n    for i, branch_nodes_prefix in enumerate(parallel_branch_node_lists):\n        subgraph_id: str = f\"{src_task_graph.graph_id}_{base_subgraph_idx_str}_b{i + 1}\"\n        subgraph: TaskGraph = TaskGraph(graph_id=subgraph_id)\n        current_subgraph_node_ids: Set[str] = set()\n\n        # Add nodes from the unique part of the branch\n        for node_id in branch_nodes_prefix:\n            if node_id == merge_node_id:\n                continue  # Merge node is added with common_downstream_nodes_ids\n            if node_id not in src_task_graph.nodes:\n                current_subgraph_node_ids.clear()\n                break  # Invalid node in branch, skip this subgraph\n            subgraph.add_node(copy.deepcopy(src_task_graph.nodes[node_id]))\n            current_subgraph_node_ids.add(node_id)\n\n        if not current_subgraph_node_ids and branch_nodes_prefix:\n            continue  # If branch_nodes_prefix was not empty but no valid nodes were added\n\n        # Add common downstream nodes (including the merge node) to the subgraph\n        for node_id in common_downstream_nodes_ids:\n            if node_id not in src_task_graph.nodes:\n                
current_subgraph_node_ids.clear()\n                break  # Invalid node, skip this subgraph\n            if node_id not in current_subgraph_node_ids:\n                subgraph.add_node(copy.deepcopy(src_task_graph.nodes[node_id]))\n                current_subgraph_node_ids.add(node_id)\n\n        if not current_subgraph_node_ids:\n            continue  # Skip if no nodes were added to the subgraph\n\n        # Adjust dependencies for nodes within the new subgraph to only refer to other nodes in the same subgraph\n        valid_subgraph_nodes: Set[str] = set(subgraph.nodes.keys())\n        for sg_node_id in list(subgraph.nodes.keys()):\n            original_node: Optional[Node] = src_task_graph.nodes.get(sg_node_id)\n            if not original_node:\n                continue\n            new_deps: List[str] = [dep for dep in original_node.dependencies if dep in valid_subgraph_nodes]\n            subgraph.nodes[sg_node_id].dependencies = new_deps\n\n        subgraph.build_adjacency_lists()  # Rebuild adj lists for the new subgraph\n        is_valid, msg = subgraph.validate_graph()\n\n        if is_valid and subgraph.nodes:\n            created_subgraphs.append(subgraph)\n        elif subgraph.nodes:\n            logger.error(f\"Invalid reconverge subgraph '{subgraph.graph_id}': {msg}.\")\n        else:\n            logger.warning(f\"Empty reconverge subgraph '{subgraph.graph_id}'.\")\n\n    return created_subgraphs\n\n\ndef split_by_fan_out_to_exits(src_task_graph: TaskGraph, naming_prefix_idx: int) -> List[TaskGraph]:\n    \"\"\"\n    Attempts to split a TaskGraph if it contains a fan-out node leading to multiple\n    distinct exit nodes that do not re-converge.\n\n    Args:\n        src_task_graph (TaskGraph): The TaskGraph to analyze and potentially split.\n        naming_prefix_idx (int): An index used for unique naming of generated subgraphs.\n\n    Returns:\n        List[TaskGraph]: A list of new TaskGraph objects if a split occurs, otherwise an empty list.\n    \"\"\"\n    src_task_graph.build_adjacency_lists()\n    if not src_task_graph.nodes:\n        return []\n\n    is_valid, msg = src_task_graph.validate_graph()  # Validate before proceeding\n    if not is_valid:\n        logger.error(f\"Fan-out: Invalid graph '{src_task_graph.graph_id}': {msg}.\")\n        return []\n\n    original_exit_node_ids: Set[str] = {n.node_id for n in src_task_graph.get_exit_nodes()}\n    if len(original_exit_node_ids) <= 1:\n        return []  # No fan-out to multiple distinct exits possible if only one or no exits\n\n    # Iterate through each node to find potential fork points\n    for fork_candidate_id in sorted(list(src_task_graph.nodes.keys())):\n        children_ids: List[str] = sorted(list(src_task_graph.adj.get(fork_candidate_id, [])))\n        if len(children_ids) < 2:\n            continue  # A fork point must have at least two children\n\n        child_branch_details: List[Tuple[str, Set[str], Set[str]]] = []\n        # For each child of the fork, determine its downstream nodes and reachable exit nodes\n        for child_id in children_ids:\n            downstream_of_child: Set[str] = get_all_downstream_nodes_recursive(src_task_graph, child_id, visited=set())\n            reachable_exits: Set[str] = downstream_of_child.intersection(original_exit_node_ids)\n            if reachable_exits:\n                child_branch_details.append((child_id, downstream_of_child, reachable_exits))\n\n        if len(child_branch_details) < 2:\n            continue  # Need at least two branches with 
reachable exits\n\n        # Check combinations of branches to find pairwise disjoint exit sets\n        for r in range(len(child_branch_details), 1, -1):  # Start with largest combos, then smaller\n            for combo in itertools.combinations(child_branch_details, r):\n                exit_sets_in_combo: List[Set[str]] = [details[2] for details in combo]\n                is_pairwise_disjoint: bool = True\n                temp_union_of_exits: Set[str] = set()\n                # Verify that the exit nodes for the chosen branches are mutually exclusive\n                for exits_set in exit_sets_in_combo:\n                    if not exits_set.isdisjoint(temp_union_of_exits):\n                        is_pairwise_disjoint = False\n                        break\n                    temp_union_of_exits.update(exits_set)\n\n                if is_pairwise_disjoint and len(temp_union_of_exits) >= r:  # Ensure there are at least 'r' distinct exits\n                    # Collect all ancestors of the fork node and the fork node itself\n                    ancestors_of_fork: Set[str] = get_all_ancestors(src_task_graph, fork_candidate_id)\n                    common_upstream_nodes: Set[str] = ancestors_of_fork | {fork_candidate_id}\n\n                    current_split_generated_graphs: List[TaskGraph] = []\n                    # Create a new subgraph for each branch in the disjoint combo\n                    for i, (child_c, downstream_nodes_c, exits_c) in enumerate(combo):\n                        subgraph_node_ids: Set[str] = common_upstream_nodes | downstream_nodes_c\n                        subgraph: TaskGraph = TaskGraph(graph_id=f\"{src_task_graph.graph_id}_fan{naming_prefix_idx}_{fork_candidate_id}_b{i + 1}\")\n\n                        # Add relevant nodes to the new subgraph\n                        for node_id_to_add in subgraph_node_ids:\n                            if node_id_to_add in src_task_graph.nodes:\n                                subgraph.add_node(copy.deepcopy(src_task_graph.nodes[node_id_to_add]))\n\n                        if not subgraph.nodes:\n                            continue\n\n                        # Adjust dependencies within the new subgraph\n                        for sg_node_id_adj in list(subgraph.nodes.keys()):\n                            original_node_deps: List[str] = src_task_graph.nodes[sg_node_id_adj].dependencies\n                            new_deps_adj: List[str] = [dep for dep in original_node_deps if dep in subgraph.nodes]\n                            subgraph.nodes[sg_node_id_adj].dependencies = new_deps_adj\n\n                        subgraph.build_adjacency_lists()  # Rebuild adj lists for the new subgraph\n                        is_sg_valid, sg_msg = subgraph.validate_graph()\n\n                        if is_sg_valid and subgraph.nodes:\n                            current_split_generated_graphs.append(subgraph)\n                        elif subgraph.nodes:\n                            logger.error(f\"Invalid fanout subgraph '{subgraph.graph_id}': {sg_msg}.\")\n\n                    if current_split_generated_graphs:  # If any valid graphs were made for this combo\n                        logger.info(f\"Fan-out split at '{fork_candidate_id}' (graph '{src_task_graph.graph_id}') -> {len(current_split_generated_graphs)} subgraphs.\")\n                        return current_split_generated_graphs  # Return the first successful split found\n    return []  # No fan-out split found for the entire graph\n\n\ndef 
split_by_reconverging_paths(src_task_graph: TaskGraph, naming_prefix_idx: int) -> List[TaskGraph]:\n    \"\"\"\n    Attempts to split a TaskGraph if it contains re-converging parallel paths.\n    It identifies common merge points and splits the graph into subgraphs,\n    each representing a unique path leading to the merge point plus the common suffix.\n\n    Args:\n        src_task_graph (TaskGraph): The TaskGraph to analyze and potentially split.\n        naming_prefix_idx (int): An index used for unique naming of generated subgraphs.\n\n    Returns:\n        List[TaskGraph]: A list of new TaskGraph objects if a split occurs, otherwise an empty list.\n    \"\"\"\n    src_task_graph.build_adjacency_lists()\n    if not src_task_graph.nodes:\n        return []\n\n    is_valid, msg = src_task_graph.validate_graph()\n    if not is_valid:\n        logger.error(f\"Reconv: Invalid graph '{src_task_graph.graph_id}': {msg}.\")\n        return []\n\n    entry_node_ids: List[str] = [n.node_id for n in src_task_graph.get_entry_nodes()]\n    exit_node_ids: List[str] = [n.node_id for n in src_task_graph.get_exit_nodes()]\n    if not entry_node_ids or not exit_node_ids:\n        return []\n\n    all_e2e_paths: List[List[str]] = []\n    path_tuples_seen: Set[Tuple[str, ...]] = set()\n    # Find all unique end-to-end paths in the graph\n    for e_id in entry_node_ids:\n        for x_id in exit_node_ids:\n            paths: List[List[str]] = find_all_paths(src_task_graph, e_id, x_id)\n            for p in paths:\n                pt: Tuple[str, ...] = tuple(p)\n                if pt not in path_tuples_seen:\n                    all_e2e_paths.append(p)\n                    path_tuples_seen.add(pt)\n\n    if len(all_e2e_paths) <= 1:\n        return []  # No parallel paths to consider for reconvergence\n\n    found_split_candidates: List[Tuple[List[List[str]], str]] = []  # Stores (list of branch node sequences, merge node ID)\n    # Compare all pairs of paths to find common suffixes and diverging prefixes\n    for i in range(len(all_e2e_paths)):\n        for j in range(i + 1, len(all_e2e_paths)):\n            path1, path2 = all_e2e_paths[i], all_e2e_paths[j]\n            merge_idx_p1, merge_idx_p2, first_common_node_in_suffix = -1, -1, None\n            # Traverse paths backward to find the first common node (merge point)\n            for k_idx_from_end in range(min(len(path1), len(path2))):\n                node_p1: str = path1[len(path1) - 1 - k_idx_from_end]\n                node_p2: str = path2[len(path2) - 1 - k_idx_from_end]\n                if node_p1 == node_p2:\n                    first_common_node_in_suffix = node_p1\n                    merge_idx_p1, merge_idx_p2 = len(path1) - 1 - k_idx_from_end, len(path2) - 1 - k_idx_from_end\n                else:\n                    break  # Paths diverge\n\n            # If a common merge node is found and branches are not empty/identical\n            if first_common_node_in_suffix is not None and merge_idx_p1 > 0 and merge_idx_p2 > 0:\n                branch1_nodes: List[str] = path1[:merge_idx_p1]\n                branch2_nodes: List[str] = path2[:merge_idx_p2]\n                if not branch1_nodes or not branch2_nodes or tuple(branch1_nodes) == tuple(branch2_nodes):\n                    continue\n\n                # Group branches that merge into the same node\n                is_new_candidate_group: bool = True\n                for cand_idx, (existing_branches_list, existing_merge_node) in enumerate(found_split_candidates):\n                    if 
existing_merge_node == first_common_node_in_suffix:\n                        is_new_candidate_group = False\n                        current_branch_tuples: Set[Tuple[str, ...]] = {tuple(b) for b in existing_branches_list}\n                        if tuple(branch1_nodes) not in current_branch_tuples:\n                            existing_branches_list.append(branch1_nodes)\n                        if tuple(branch2_nodes) not in current_branch_tuples:\n                            existing_branches_list.append(branch2_nodes)\n                        break\n                if is_new_candidate_group:\n                    found_split_candidates.append(([branch1_nodes, branch2_nodes], first_common_node_in_suffix))\n\n    if not found_split_candidates:\n        return []\n\n    all_resulting_subgraphs: List[TaskGraph] = []\n    processed_split_signatures_for_subgraph_gen: Set[Tuple[Tuple[Tuple[str, ...], ...], str]] = set()\n\n    # Process each found split candidate to generate subgraphs\n    for split_idx, (branch_definitions, merge_node) in enumerate(found_split_candidates):\n        if len(branch_definitions) < 2:\n            continue  # Need at least two branches for a split\n\n        # Sort branches for canonical signature to avoid redundant processing\n        sorted_branch_tuples: Tuple[Tuple[str, ...], ...] = tuple(sorted([tuple(b) for b in branch_definitions]))\n        current_split_signature: Tuple[Tuple[Tuple[str, ...], ...], str] = (sorted_branch_tuples, merge_node)\n\n        if current_split_signature in processed_split_signatures_for_subgraph_gen:\n            continue\n        processed_split_signatures_for_subgraph_gen.add(current_split_signature)\n\n        naming_str: str = f\"reconv{naming_prefix_idx}_{merge_node}_s{split_idx}\"\n        subgraphs_from_this_split: List[TaskGraph] = split_single_structure(src_task_graph, branch_definitions, merge_node, naming_str)\n        if subgraphs_from_this_split:  # Only extend if non-empty\n            logger.info(f\"Re-converge split for merge '{merge_node}' (graph '{src_task_graph.graph_id}') -> {len(subgraphs_from_this_split)} subgraphs.\")\n            all_resulting_subgraphs.extend(subgraphs_from_this_split)\n\n    return all_resulting_subgraphs\n\n\ndef discover_and_split_parallel_paths(src_task_graph: TaskGraph) -> List[TaskGraph]:\n    \"\"\"\n    Discovers and splits a TaskGraph into irreducible subgraphs by iteratively identifying\n    and splitting fan-out and re-converging parallel paths.\n\n    Args:\n        src_task_graph (TaskGraph): The original TaskGraph to be analyzed and split.\n\n    Returns:\n        List[TaskGraph]: A list of TaskGraph objects, where each represents an\n                         irreducible (cannot be further split by these rules) subgraph.\n    \"\"\"\n    if not src_task_graph or not src_task_graph.nodes:\n        logger.info(\"Input graph is empty. Nothing to split.\")\n        return []\n\n    initial_graph_copy: TaskGraph = src_task_graph.copy() if hasattr(src_task_graph, \"copy\") else copy.deepcopy(src_task_graph)\n    initial_graph_copy.build_adjacency_lists()  # Ensure it's ready for validation\n    is_valid, msg = initial_graph_copy.validate_graph()\n    if not is_valid:\n        logger.error(f\"Original graph '{initial_graph_copy.graph_id}' is invalid: {msg}. 
Cannot split.\")\n        return [initial_graph_copy]\n\n    final_irreducible_graphs: List[TaskGraph] = []\n    processing_queue: List[TaskGraph] = [initial_graph_copy]\n    # Use structural signature to track processed graphs and avoid redundant work on identical structures\n    processed_structural_signatures_in_queue: Set[str] = set()\n\n    iteration_counter: int = 0  # For unique naming of intermediate graphs\n\n    # Process graphs in a queue until no more splits can be made\n    while processing_queue:\n        current_graph: TaskGraph = processing_queue.pop(0)\n        current_graph_structural_sig: str = generate_structural_signature(current_graph)\n\n        if current_graph_structural_sig in processed_structural_signatures_in_queue:\n            logger.debug(f\"Skipping already processed graph structure (sig: {current_graph_structural_sig[:70]}...) for graph ID {current_graph.graph_id}\")\n            continue\n        # Add to processed signatures AFTER successful splitting attempt (or if it's irreducible)\n        # processed_structural_signatures_in_queue.add(current_graph_structural_sig) # Moved this to after potential splits\n\n        iteration_counter += 1\n        split_occurred_this_pass: bool = False\n\n        # 1. Try to split by fan-out to distinct exits\n        graphs_after_fan_out: List[TaskGraph] = split_by_fan_out_to_exits(current_graph, iteration_counter)\n        if graphs_after_fan_out:  # Non-empty list means split occurred\n            for g_fan in graphs_after_fan_out:\n                # Add newly created subgraphs to the queue for further processing\n                processing_queue.append(g_fan)\n            split_occurred_this_pass = True\n\n        if split_occurred_this_pass:\n            logger.debug(f\"Graph '{current_graph.graph_id}' processed with fan-out. Re-evaluating queue.\")\n            # If a split occurred, the current graph is replaced by its subgraphs.\n            # Mark its structural signature as processed so it's not re-added\n            processed_structural_signatures_in_queue.add(current_graph_structural_sig)\n            continue  # Continue to next graph in queue\n\n        # 2. If no fan-out split, try to split by re-converging paths\n        graphs_after_reconverge: List[TaskGraph] = split_by_reconverging_paths(current_graph, iteration_counter)\n        if graphs_after_reconverge:  # Non-empty list means split occurred\n            for g_reconv in graphs_after_reconverge:\n                # Add newly created subgraphs to the queue for further processing\n                processing_queue.append(g_reconv)\n            split_occurred_this_pass = True\n\n        if split_occurred_this_pass:\n            logger.debug(f\"Graph '{current_graph.graph_id}' processed with re-convergence. Re-evaluating queue.\")\n            # If a split occurred, the current graph is replaced by its subgraphs.\n            # Mark its structural signature as processed so it's not re-added\n            processed_structural_signatures_in_queue.add(current_graph_structural_sig)\n            continue  # Continue to next graph in queue\n\n        # If no split of any type happened on current_graph, it's irreducible\n        if not split_occurred_this_pass:\n            logger.info(f\"Graph '{current_graph.graph_id}' (struct_sig: {current_graph_structural_sig[:70]}...) is irreducible. 
Adding to final list.\")\n            final_irreducible_graphs.append(current_graph)\n            # Mark its structural signature as processed only when it is declared irreducible\n            processed_structural_signatures_in_queue.add(current_graph_structural_sig)\n\n    # Deduplicate final list of graphs based on structure and assign canonical names\n    unique_final_graphs_map: Dict[str, TaskGraph] = {}\n    true_final_graphs: List[TaskGraph] = []\n    base_id_for_final_naming: str = src_task_graph.graph_id\n\n    for g in final_irreducible_graphs:\n        structural_sig: str = generate_structural_signature(g)\n        if structural_sig not in unique_final_graphs_map:\n            new_final_id: str = f\"{base_id_for_final_naming}_split_{len(unique_final_graphs_map) + 1}\"\n            g.graph_id = new_final_id  # Update graph_id for the returned unique graph\n            unique_final_graphs_map[structural_sig] = g\n            true_final_graphs.append(g)\n        else:\n            logger.debug(f\"Skipping structurally duplicate final graph: current id {g.graph_id} (already found as {unique_final_graphs_map[structural_sig].graph_id}).\")\n\n    logger.info(f\"Original graph '{src_task_graph.graph_id}' resulted in {len(true_final_graphs)} unique irreducible TaskGraph(s).\")\n    return true_final_graphs\n\n\nif __name__ == \"__main__\":\n    # Ensure a directory for DAG images exists\n    if not os.path.exists(\"dag_images\"):\n        os.makedirs(\"dag_images\")\n\n    # Example 1: Re-converging paths\n    logger.info(f\"\\n--- Splitting graph: ex1_reconverge ---\")\n    node_a = Node(node_id=\"A\", node_type=NodeType.DATA_LOAD)\n    node_b = Node(node_id=\"B\", node_type=NodeType.COMPUTE, dependencies=[\"A\"])\n    node_a1 = Node(node_id=\"A1\", node_type=NodeType.DATA_LOAD)\n    node_b1 = Node(node_id=\"B1\", node_type=NodeType.COMPUTE, dependencies=[\"A1\"])\n    node_c = Node(node_id=\"C\", node_type=NodeType.COMPUTE, dependencies=[\"B\", \"B1\"])  # Re-convergence point\n    node_d_ex1 = Node(node_id=\"D_ex1\", node_type=NodeType.COMPUTE, dependencies=[\"C\"])\n    node_e_ex1 = Node(node_id=\"E_ex1\", node_type=NodeType.MODEL_TRAIN, dependencies=[\"D_ex1\"])\n\n    original_graph_ex1 = TaskGraph(graph_id=\"ex1_reconverge\")\n    original_graph_ex1.add_nodes([node_a, node_b, node_a1, node_b1, node_c, node_d_ex1, node_e_ex1])\n    if original_graph_ex1.nodes:\n        original_graph_ex1.save_dag_pic(filename=original_graph_ex1.graph_id + \"_orig_pic\", directory=\"dag_images\")\n\n    split_graphs1 = discover_and_split_parallel_paths(original_graph_ex1)\n    for idx, sg in enumerate(split_graphs1):\n        logger.info(f\"Final Subgraph {idx + 1}: {sg.graph_id} with {len(sg.nodes)} nodes.\")\n        if sg.nodes:\n            sg.save_dag_pic(filename=sg.graph_id + \"_pic\", directory=\"dag_images\")\n    logger.info(f\"--- Finished Ex1 ---\\n\")\n\n    # Example 2: Complex (fan-out and re-converging)\n    logger.info(f\"\\n--- Splitting graph: ex2_complex ---\")\n    node_x = Node(\"X\", NodeType.DATA_LOAD)\n    node_y = Node(\"Y\", NodeType.DATA_LOAD)\n    node_p1 = Node(\"P1\", NodeType.COMPUTE, dependencies=[\"X\"])\n    node_p2 = Node(\"P2\", NodeType.COMPUTE, dependencies=[\"Y\"])\n    node_m1 = Node(\"M1\", NodeType.COMPUTE, dependencies=[\"P1\", \"P2\"])  # Re-convergence 1\n    node_p3 = Node(\"P3\", NodeType.DATA_LOAD)\n    node_z = Node(\"Z\", NodeType.COMPUTE, dependencies=[\"M1\", \"P3\"])  # Re-convergence 2 (Z depends on P3 and M1 which itself is a merge)\n    
node_j1 = Node(\"J1\", NodeType.COMPUTE, dependencies=[\"Z\"])\n    node_j2 = Node(\"J2\", NodeType.COMPUTE, dependencies=[\"Z\"])\n    node_k1 = Node(\"K1\", NodeType.MODEL_TRAIN, dependencies=[\"J1\"])  # Exit 1\n    node_k2 = Node(\"K2\", NodeType.MODEL_TRAIN, dependencies=[\"J2\"])  # Exit 2 (Fan-out at Z to J1, J2 leading to distinct exits K1, K2)\n\n    complex_graph_ex2 = TaskGraph(graph_id=\"ex2_complex\")\n    complex_graph_ex2.add_nodes([node_x, node_y, node_p1, node_p2, node_m1, node_p3, node_z, node_j1, node_j2, node_k1, node_k2])\n    if complex_graph_ex2.nodes:\n        complex_graph_ex2.save_dag_pic(filename=complex_graph_ex2.graph_id + \"_orig_pic\", directory=\"dag_images\")\n\n    split_graphs2 = discover_and_split_parallel_paths(complex_graph_ex2)\n    for idx, sg in enumerate(split_graphs2):\n        logger.info(f\"Final Subgraph {idx + 1}: {sg.graph_id} with {len(sg.nodes)} nodes.\")\n        if sg.nodes:\n            sg.save_dag_pic(filename=sg.graph_id + \"_pic\", directory=\"dag_images\")\n    logger.info(f\"--- Finished Ex2 ---\\n\")\n\n    # Example 3: Simple linear graph (should not split)\n    logger.info(f\"\\n--- Splitting graph: ex3_linear ---\")\n    linear_graph_ex3 = TaskGraph(graph_id=\"ex3_linear\")\n    linear_graph_ex3.add_nodes([Node(\"L1\", NodeType.DATA_LOAD), Node(\"L2\", NodeType.COMPUTE, dependencies=[\"L1\"]), Node(\"L3\", NodeType.MODEL_TRAIN, dependencies=[\"L2\"])])\n    if linear_graph_ex3.nodes:\n        linear_graph_ex3.save_dag_pic(filename=linear_graph_ex3.graph_id + \"_orig_pic\", directory=\"dag_images\")\n    split_graphs3 = discover_and_split_parallel_paths(linear_graph_ex3)\n    for idx, sg in enumerate(split_graphs3):\n        logger.info(f\"Final Subgraph {idx + 1}: {sg.graph_id} ({len(sg.nodes)} nodes)\")\n        if sg.nodes:\n            sg.save_dag_pic(filename=sg.graph_id + \"_pic\", directory=\"dag_images\")\n    logger.info(f\"--- Finished Ex3 ---\\n\")\n\n    # Example 4: Fan-out, no re-merge (should split into two distinct paths)\n    logger.info(f\"\\n--- Splitting graph: ex4_fanout_only ---\")\n    split_no_merge_graph_ex4 = TaskGraph(graph_id=\"ex4_fanout_only\")\n    split_no_merge_graph_ex4.add_nodes(\n        [\n            Node(\"S_A\", NodeType.DATA_LOAD),\n            Node(\"S_B_exit1\", NodeType.COMPUTE, dependencies=[\"S_A\"]),  # Path 1\n            Node(\"S_C_exit2\", NodeType.COMPUTE, dependencies=[\"S_A\"]),  # Path 2\n        ]\n    )\n    if split_no_merge_graph_ex4.nodes:\n        split_no_merge_graph_ex4.save_dag_pic(filename=split_no_merge_graph_ex4.graph_id + \"_orig_pic\", directory=\"dag_images\")\n    split_graphs4 = discover_and_split_parallel_paths(split_no_merge_graph_ex4)\n    for idx, sg in enumerate(split_graphs4):\n        logger.info(f\"Final Subgraph {idx + 1}: {sg.graph_id} ({len(sg.nodes)} nodes)\")\n        if sg.nodes:\n            sg.save_dag_pic(filename=sg.graph_id + \"_pic\", directory=\"dag_images\")\n    logger.info(f\"--- Finished Ex4 ---\\n\")\n\n    # Example 5: Three-way re-merge (should split into three paths)\n    logger.info(f\"\\n--- Splitting graph: ex5_3way_reconverge ---\")\n    three_way_graph_ex5 = TaskGraph(graph_id=\"ex5_3way_reconverge\")\n    three_way_graph_ex5.add_nodes(\n        [\n            Node(\"3W_A\", NodeType.DATA_LOAD),\n            Node(\"3W_B\", NodeType.DATA_LOAD),\n            Node(\"3W_E\", NodeType.DATA_LOAD),\n            Node(\"3W_C\", NodeType.COMPUTE, dependencies=[\"3W_A\", \"3W_B\", \"3W_E\"]),  # 3-way re-convergence\n            
Node(\"3W_D_exit\", NodeType.MODEL_TRAIN, dependencies=[\"3W_C\"]),\n        ]\n    )\n    if three_way_graph_ex5.nodes:\n        three_way_graph_ex5.save_dag_pic(filename=three_way_graph_ex5.graph_id + \"_orig_pic\", directory=\"dag_images\")\n    split_graphs5 = discover_and_split_parallel_paths(three_way_graph_ex5)\n    for idx, sg in enumerate(split_graphs5):\n        logger.info(f\"Final Subgraph {idx + 1}: {sg.graph_id} ({len(sg.nodes)} nodes)\")\n        if sg.nodes:\n            sg.save_dag_pic(filename=sg.graph_id + \"_pic\", directory=\"dag_images\")\n    logger.info(f\"--- Finished Ex5 ---\\n\")\n\n    logger.info(\"All examples processed. Check the 'dag_images' folder.\")\n"
  },
  {
    "path": "siirl/execution/metric_worker/metric_worker.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport ray\nimport asyncio\n\nfrom ray.actor import ActorHandle\n\nfrom typing import Optional, Any\nfrom loguru import logger\nfrom tensordict import TensorDict\n\nfrom siirl.execution.metric_worker.utils import *\nfrom siirl.utils.metrics.metric_utils import *\n\n# Special metric configurations where specific aggregation logic is needed\n# e.g., \"graph_output_handling\" uses MaxMetric since only rollout_tp 0 needs data buffer handling\nSpecial_Metric = {\n    \"graph_output_handling\": MaxMetric,\n}\n\nclass MetricClient():\n    \"\"\"Client class for interacting with the MetricWorker actor\n    \n    Provides methods to submit metrics, wait for submissions to complete,\n    and retrieve final aggregated metrics from the worker.\n    \"\"\"\n    def __init__(self, metric_worker: ActorHandle):\n        \"\"\"Initialize MetricClient with a reference to the MetricWorker actor\n        \n        Args:\n            metric_worker: Ray actor handle for the MetricWorker instance\n        \"\"\"\n        self.metric_worker = metric_worker\n        self.fut = []  # List to track pending metric submission futures\n        \n    def stop(self):\n        \"\"\"Stop the metric worker and terminate its processing loop\"\"\"\n        self.is_running = False\n        ray.get(self.metric_worker.stop.remote())\n    \n    def submit_metric(self, metrics: dict, world_size):\n        \"\"\"Submit a dictionary of metrics to the worker for aggregation\n        \n        Args:\n            metrics: Dictionary containing metric names and values\n            world_size: Total number of processes in the distributed system\n        \"\"\"\n        self.fut.append(self.metric_worker.submit_metric.remote(metrics, world_size))\n\n    \n    def wait_submit(self):\n        \"\"\"Wait for all pending metric submissions to complete\"\"\"\n        ray.get(self.fut)\n        self.fut = []  # Clear the list after all futures are resolved\n    \n    def wait_final_res(self):\n        \"\"\"Retrieve the final aggregated metrics from the worker\n        \n        Returns:\n            Dictionary of aggregated metrics\n        \"\"\"\n        metrics = ray.get(self.metric_worker.wait_final_res.remote())\n        return metrics\n    \n    def compute_local_data_metric(self, data: TensorDict, world_size: int):\n        \"\"\"Compute and submit metrics from local data\n        \n        Extracts necessary fields from the TensorDict, computes metrics,\n        and submits them to the worker.\n        \n        Args:\n            data: TensorDict containing the data to process\n            world_size: Total number of processes in the distributed system\n        \"\"\"\n        need_key = [\"responses\", \"attention_mask\", \"token_level_scores\",\n            \"token_level_rewards\", \"advantages\", \"returns\", \"values\", \"response_mask\", \"__num_turns__\"]\n        # Add optional keys if 
present in data\n        need_key = [key for key in need_key if key in data.keys()]\n        \n        need_data = data.select(*need_key)\n        self.fut.append(self.metric_worker.submit_metric.remote(\n            compute_data_metric(need_data), world_size)\n        )\n\n    def compute_local_throughout_metrics(self, data: TensorDict, timing_raw: dict, n_gpu: int, world_size: int):\n        \"\"\"Compute and submit throughput metrics\n        \n        Args:\n            data: TensorDict containing relevant data (e.g., token counts)\n            timing_raw: Dictionary containing raw timing data\n            n_gpu: Number of GPUs used (should be 1)\n            world_size: Total number of processes in the distributed system\n        \"\"\"\n        need_key = [\"global_token_num\"]\n        need_data = data.select(*need_key)\n        self.fut.append(self.metric_worker.submit_metric.remote(\n            compute_throughout_metrics(need_data, timing_raw, n_gpu), world_size)\n        )\n        \n    def compute_local_timing_metrics(self, data: TensorDict, timing_raw: dict, world_size: int):\n        \"\"\"Compute and submit timing metrics\n        \n        Args:\n            data: TensorDict containing relevant data\n            timing_raw: Dictionary containing raw timing data\n            world_size: Total number of processes in the distributed system\n        \"\"\"\n        self.fut.append(self.metric_worker.submit_metric.remote(\n            compute_timing_metrics(data, timing_raw), world_size)\n        )\n    \n    def process_local_validation_metrics(self, data_sources: list[str], sample_inputs: list[str], \n                                        infos_dict: dict[str, list[Any]], sample_turns: list[int], world_size: int):\n        \"\"\"Process and submit validation metrics\n        \n        Args:\n            data_sources: List of data source names\n            sample_inputs: List of sample input strings\n            infos_dict: Dictionary containing information about validation samples\n            sample_turns: List of turn counts for each sample\n            world_size: Total number of processes in the distributed system\n        \"\"\"\n        self.fut.append(self.metric_worker.submit_metric.remote(\n            process_validation_metrics(data_sources, sample_inputs, infos_dict, sample_turns), world_size)\n        )\n    \n    \n    \n\n@ray.remote(num_cpus=1)\nclass MetricWorker:\n    \"\"\"Ray actor responsible for aggregating metrics from distributed processes\n    \n    Runs an asynchronous loop to process incoming metrics, aggregate them\n    across all processes, and provide final results when requested.\n    \"\"\"\n    def __init__(self) -> None:\n        self.metric_queue = asyncio.Queue()  # Queue for incoming metric submissions\n        self.is_running = False  # Flag to control the processing loop\n        self.process_task: Optional[asyncio.Task] = None  # Task for the main processing loop\n        self.step = 0  # Current step counter (not actively used in shown code)\n        self.final_metrics = {}  # Aggregated final metrics\n        self.working_metrics = {}  # Metrics currently being collected/aggregated\n    \n    async def start(self):\n        \"\"\"Start the metrics processing loop\n        \n        Initializes and starts the asynchronous loop that processes metrics\n        from the queue.\n        \"\"\"\n        if self.is_running:\n            return\n        \n        self.is_running = True\n        self.process_task = 
asyncio.create_task(self._process_metrics_loop())\n\n    async def submit_metric(self, metric: dict, world_size: int):\n        \"\"\"Submit a metric dictionary to the worker's processing queue\n        \n        Args:\n            metric: Dictionary of metric names and values\n            world_size: Total number of processes in the distributed system\n        \"\"\"\n        await self.metric_queue.put((metric, world_size))\n\n        \n    \n    async def stop(self):\n        \"\"\"Stop the metrics processing loop and clean up resources\"\"\"\n        self.is_running = False\n        if self.process_task:\n            self.process_task.cancel()\n            try:\n                await self.process_task\n            except asyncio.CancelledError:\n                pass  # Expected when task is cancelled\n    \n    async def compute_metric(self, metric_name, metrics):\n        \"\"\"Compute the final aggregated value for a metric\n        \n        Uses the appropriate metric function to aggregate values from all processes.\n        \n        Args:\n            metric_name: Name of the metric to compute\n            metrics: List of Metric objects containing values from each process\n        \"\"\"\n        metric_func = MetricFunc(metric_name)\n        res = metric_func(metrics)\n        self.working_metrics.pop(metric_name)  # Remove from working set after computation\n        \n        # Rename timing metrics for consistency in output\n        if metric_name.startswith(\"timing_s/\"):\n            metric_name = metric_name.replace(\"timing_s/\", \"perf/delta_time/\")\n            \n        self.final_metrics[metric_name] = res\n        \n           \n    async def parse_metric(self, metric_data: tuple):\n        \"\"\"Process incoming metric data and aggregate when all processes have submitted\n        \n        Collects metric values from each process and triggers computation when\n        all values (one per process) have been received.\n        \n        Args:\n            metric_data: Tuple containing (metric_dict, world_size)\n        \"\"\"\n        metric_dict, world_size = metric_data\n        \n        for key, value in metric_dict.items():\n            metric = Metric(name=key, value=value, world_size=world_size)\n            \n            # Add to working metrics or compute if all values are collected\n            if key in self.working_metrics:\n                metrics = self.working_metrics[key]\n                metrics.append(metric)\n                # Check if we have received values from all processes\n                if len(metrics) == world_size:\n                    await self.compute_metric(key, metrics)\n            else:\n                self.working_metrics[key] = [metric]\n            \n    \n    async def _process_metrics_loop(self):\n        \"\"\"Main loop for processing metrics from the queue\n        \n        Continuously retrieves and processes metric data while the worker is running.\n        \"\"\"\n        while self.is_running:\n            metric_data = await self.metric_queue.get()            \n            await self.parse_metric(metric_data)\n            \n\n    async def wait_final_res(self):\n        \"\"\"Wait for all metrics to be processed and return the final results\n        \n        Ensures all remaining metrics in the queue are processed, computes any\n        remaining aggregated values, and returns the final metrics.\n        \n        Returns:\n            Dictionary of final aggregated metrics\n        \"\"\"\n        await 
self.stop()\n        \n        # Process any remaining metrics in the queue\n        while self.metric_queue.qsize():      \n            metric_data = await self.metric_queue.get()\n            await self.parse_metric(metric_data)\n        \n        # Compute any metrics still in working set\n        items = list(self.working_metrics.items())\n        for key, value in items:\n            await self.compute_metric(key, value)\n        \n        # Restart the worker for potential future use\n        await self.start()\n        \n        # Capture and reset metrics before returning\n        final_metrics = self.final_metrics\n        self.final_metrics = {}\n        self.working_metrics = {}\n        return final_metrics"
  },
  {
    "path": "siirl/execution/metric_worker/utils.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport torch\nfrom statistics import mean\nfrom typing import List, Optional, Union, Any, Dict\nfrom dataclasses import dataclass\nfrom tensordict import TensorDict\n\n\n@dataclass\nclass Metric:\n    name: None\n    value: Any\n    world_size: None\n\ndef MetricFunc(name: str):\n    if \"min\" in name:\n        return MinMetric\n    elif \"max\" in name:\n        return MaxMetric\n    elif \"sum\" in name or \"total\" in name:\n        return SumMetric\n    else:\n        return MeanMetric\n\ndef SumMetric(metrics: List[Metric]):\n    value = [v\n        for metric in metrics\n        for v in (metric.value if isinstance(metric.value, list) else [metric.value])]\n    return sum(value)\n\ndef MeanMetric(metrics: List[Metric]):\n    value = [v\n        for metric in metrics\n        for v in (metric.value if isinstance(metric.value, list) else [metric.value])]\n    return mean(value)\n\n\ndef MaxMetric(metrics: List[Metric]):\n    value = [v\n        for metric in metrics\n        for v in (metric.value if isinstance(metric.value, list) else [metric.value])]\n    return max(value)\n\ndef MinMetric(metrics: List[Metric]):\n    value = [v\n        for metric in metrics\n        for v in (metric.value if isinstance(metric.value, list) else [metric.value])]\n    return min(value)"
  },
  {
    "path": "siirl/execution/rollout_flow/multi_agent/multiagent_generate.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport ray\nimport torch\nimport asyncio\nimport random\nimport copy\nimport numpy as np\nimport torch.distributed  as dist\n\nfrom asyncio import Queue\nfrom siirl.models.loader import load_tokenizer\nfrom tensordict import TensorDict\nfrom uuid import uuid4\nfrom .utils import AgentOutput, AgentOutputStatus\nfrom typing import Dict, List, Any, Tuple, Optional, Union\nfrom codetiming import Timer\n\nfrom loguru import logger\nfrom siirl.engine.rollout.sglang_rollout.async_sglang_server import AsyncSglangServer\nfrom siirl.engine.fsdp_workers import ActorRolloutRefWorker\nfrom siirl.execution.dag import TaskGraph, Node, NodeRole, NodeType\nfrom siirl.params import RolloutArguments, ActorRolloutRefArguments\nfrom siirl.dag_worker.dag_utils import remove_prefix_from_dataproto\n\nclass MultiAgentLoop():\n    def __init__(self, dag, config: ActorRolloutRefArguments, node_workers:Dict, local_dag:TaskGraph, databuffer:List[\"ray.actor.ActorHandle\"], placement_mode: str = 'colocate'):\n        # dely import Dag after dagworker finish init\n        from siirl.dag_worker.dagworker import DAGWorker\n        assert config.rollout.name == 'sglang', \"MultiAgent only support sglang because vllm can't sleep in multi times\"\n        self.dag:DAGWorker = dag\n        self.graph = local_dag\n        self.placement_mode = placement_mode\n        self.rollout_config = config.rollout\n        self.internal_data_cache: Dict[str, Queue] = {}\n        self.data_buffers = databuffer\n        self.workers = node_workers\n        self.max_model_len = int(self.rollout_config.max_model_len or self.rollout_config.prompt_length + self.rollout_config.response_length)\n        self.max_model_len = min(self.max_model_len, self.rollout_config.prompt_length + self.rollout_config.response_length)\n        self._parse_graph(local_dag)\n        self.finish_generate = False \n        assert placement_mode == 'colocate' #in ['colocate', 'spread']\n        if self.rollout_config.multi_turn.max_assistant_turns is None:\n            self.rollout_config.multi_turn.max_assistant_turns = 1\n    \n        \n    def _parse_graph(self, graph:TaskGraph):\n        node_queue = graph.get_entry_nodes()\n        visited_nodes = set()\n        self.node_queue = []\n        while node_queue:\n            cur_node = node_queue.pop(0)\n            if cur_node.node_role != NodeRole.ROLLOUT:\n                break\n            self.node_queue.append(cur_node)\n            next_nodes = graph.get_downstream_nodes(cur_node.node_id)\n            for n in next_nodes:\n                if n.node_id not in visited_nodes:\n                    node_queue.append(n)\n        tail_node = self.node_queue[-1]\n        tail_worker :ActorRolloutRefWorker = self.workers[self._generate_node_worker_key(tail_node)]\n        self.tail_device_mesh = tail_worker.rollout.get_device_mesh()\n\n\n    def 
_generate_node_worker_key(self, node: Node) -> str:\n        \"\"\"Generates a unique string key for a node's worker instance.\"\"\"\n        return f\"{node.agent_group}_{node.node_type.value}_{node.node_role.value}\"\n\n    def node_if_local(self, node):\n        # Used in spread mode to judge whether the agent node is on the current gpu_worker\n        pass\n\n    def _preprocess(self, batch:TensorDict) -> Tuple[List, List]:\n        '''Preprocess data from the dataloader and return the prompts and ground truths for generation'''\n        n = 1 if batch.meta_info.get(\"validate\", False) else self.rollout_config.n\n        batch = batch.repeat(n, interleave=True)\n        raw_prompts = batch.non_tensor_batch[\"raw_prompt\"]\n        reward_model = batch.non_tensor_batch['reward_model'] if self.rollout_config.agent.rewards_with_env else None\n        raw_prompts = [p.tolist() for p in raw_prompts]\n        if reward_model is None:\n            ground_truth = [-1] * len(raw_prompts)\n        else:\n            ground_truth = [reward['ground_truth'] for reward in reward_model]\n        return raw_prompts, ground_truth\n\n    def _generate_key(self, cur_node: Node, next_node: Node, batch_id: int, global_bs: int = 0):\n        \"\"\"\n        Generates a unique key for routing data between nodes in the DAG, considering their Data Parallel (DP) configurations.\n        The key ensures data is correctly routed from the current node to the next node, accounting for differences\n        in DP sizes (e.g., when the current node's DP size is larger or smaller than the next node's).\n        Args:\n            cur_node: Current node in the DAG (source of the data).\n            next_node: Next node in the DAG (destination for the data).\n            batch_id: Index of the sample within the local batch (for fine-grained routing).\n            global_bs: Global batch size (optional, used when DP sizes differ).\n        Returns:\n            str: Unique key formatted as \"{next_node_id}_{destination_dp_rank}\" for data routing.\n        \"\"\"\n        cur_dp_size, cur_dp_rank, *_ = self.dag._get_node_dp_info(cur_node)\n        next_dp_size, next_dp_rank, *_ = self.dag._get_node_dp_info(next_node)\n        if cur_dp_size == next_dp_size:\n            return f\"{next_node.node_id}_{next_dp_rank}\"\n        elif cur_dp_size > next_dp_size:\n            assert cur_dp_size % next_dp_size == 0, f\"dp size of {cur_node.node_id} should be divisible by that of {next_node.node_id}\"\n            return f\"{next_node.node_id}_{cur_dp_rank // next_dp_size}\"\n        else:\n            assert next_dp_size % cur_dp_size == 0, f\"dp size of {next_node.node_id} should be divisible by that of {cur_node.node_id}\"\n            # todo(hujr): may not be suitable for GRPO\n            next_rank_range = next_dp_size // cur_dp_size\n            next_bs_range = global_bs // next_dp_size\n            next_rank = batch_id // next_bs_range + cur_dp_rank * next_rank_range\n            return f\"{next_node.node_id}_{next_rank}\"\n\n    async def async_put_data(self, key: str, value: Union[AgentOutput, str], source_dp_size: int, dest_dp_size: int, timing_raw: Dict[str, float]):\n        \"\"\"\n        Asynchronously puts data into local cache or distributed data buffers, based on source and destination DP sizes.\n        - Uses local cache when source and destination DP sizes match (no cross-DP communication needed).\n        - Uses distributed buffers when DP sizes differ (requires cross-DP data sharing).\n\n        Args:\n            key: Unique key to identify the data (generated by `_generate_key`).\n   
         value: Data to store (either an `AgentOutput` object or a string).\n            source_dp_size: DP size of the source node.\n            dest_dp_size: DP size of the destination node.\n            timing_raw: Dictionary to track timing metrics (not yet implemented).\n        \"\"\"\n        if  source_dp_size == dest_dp_size:\n            if isinstance(value, AgentOutput):\n                if key not in self.internal_data_cache:\n                    self.internal_data_cache[key] = Queue()\n                await self.internal_data_cache[key].put(value)\n            elif isinstance(value, str):\n                self.internal_data_cache[key] = value\n            else:\n                raise NotImplementedError(\"This Should not Happen\")\n        else:   \n            # random save to databuffers\n            buffer = random.choice(self.data_buffers)\n            await buffer.put.remote(key, value)\n            \n    async def async_get_envdata(self, key: str, timing_raw: Dict[str, float]):\n        \"\"\"\n        Asynchronously retrieves environment-related data (e.g., observations) from local cache or distributed buffers.\n        Checks local cache first for efficiency; falls back to distributed buffers if not found.\n\n        Args:\n            key: Unique key identifying the environment data.\n            timing_raw: Dictionary to track timing metrics (not yet implemented).\n\n        Returns:\n            The requested data (if found) or None (if not found).\n        \"\"\"\n        data = None  \n        if key in self.internal_data_cache:\n            data = self.internal_data_cache.pop(key, None)\n        else:\n            tasks = [buffer.pop.remote(key) for buffer in self.data_buffers]\n           \n            temp_data = await asyncio.gather(*tasks)\n            # temp_data = self.data_buffers.get(key) \n            data = [t for t in temp_data if t is not None]\n        return data[0] if data else None\n    \n    async def async_get_data(self, key: str, timing_raw: Dict[str, float]):\n        \"\"\"\n        Asynchronously retrieves generated data (e.g., AgentOutputs) from local cache or distributed buffers.\n        \n        Handles queue-based data in local cache (for multiple entries) and aggregates results from distributed buffers.\n\n        Args:\n            key: Unique key identifying the data (generated by `_generate_key`).\n            timing_raw: Dictionary to track timing metrics (not yet implemented).\n\n        Returns:\n            List of retrieved data entries (or None if no data found).\n        \"\"\"    \n        data = None  \n        if key in self.internal_data_cache:\n            queue:Queue = self.internal_data_cache.get(key)\n            while queue.qsize() > 0:\n                if data:\n                    data.append(await queue.get())\n                else:\n                    data = [await queue.get()]\n        else:\n            tasks = [buffer.get_queue.remote(key) for buffer in self.data_buffers]\n            temp_data = await asyncio.gather(*tasks)\n            # temp_data = self.data_buffers.get(key) \n            data = [item for t in temp_data if t is not None for item in t]\n            \n        return data\n    \n\n    async def spread_task(self, cur_node, node_idx, batch_idx):\n        ''' Not support now'''\n        while True:\n            key = self._generate_key(batch_idx, cur_node.dp_rank, cur_node.node_id)\n            prompt,should_stop = self.databuffer.get(key)\n            if multiturn > max_multiturn:\n            
    return response, response_mask\n            if should_stop:\n                if node_idx != len(self.node_queue) - 1:\n                    next_node = self.node_queue[node_idx + 1] if node_idx < len(self.node_queue) - 1  else self.node_queue[0]\n                    next_key = self._generate_key(batch_idx, next_node.dp_rank, next_node.node_id)\n                    self.databuffer.put(next_key, [[], should_stop]) \n                return response, response_mask\n            response = await node_worker.rollout.generate(XXXX)\n            prompt = prompt + response\n            if cur_node.node_id == self.node_queue[-1]:\n                # last agent need to interaction with env\n                tool_response, rewards, should_stop = self.env.execute(prompt)\n                prompt = prompt + tool_response\n            next_node = self.node_queue[node_idx + 1] if node_idx < len(self.node_queue) - 1  else self.node_queue[0]\n            next_key = self._generate_key(batch_idx, next_node.dp_rank, next_node.node_id)\n            self.databuffer.put(next_key, [prompt, should_stop]) \n            multiturn = multiturn + 1\n                \n    async def generate_spread(self):\n        ''' Not support now'''\n        for i in range(len(self.node_queue)):\n            cur_node = self.node_queue[i]\n            if node_if_local(cur_node) is False:\n                continue\n            dp_rank,dp_size,tp_rank,tp_size = get_node_info(cur_node)\n            node_worker = self.workers[self._generate_node_worker_key(cur_node)]\n            # wakeup before generate\n            node_worker.wake_up()\n            if tp_rank == 0:\n                bs = get_batch_size(cur_node)\n                tasks = []\n                for bs_idx in range(bs):\n                    tasks.append(spread_task(cur_node, i, bs_idx))\n                res = await asyncio.gather(*tasks)\n            barrier()\n            node_worker.sleep()\n            return res\n    async def check_colocate_running(self, finished_res: Dict, visited_agentoutputs: Dict):\n        \"\"\"\n        Asynchronously checks whether the current worker should continue running (local_running status)\n        by verifying if all tracked samples have been fully processed across all DAG nodes.\n        Args:\n            finished_res: Dictionary tracking finished samples. 
Key = node ID, Value = set of \n                        request IDs that have been fully processed for that node.\n            visited_agentoutputs: Set (or dict-like) of request IDs representing samples the \n                                current worker has fetched/processed in the current cycle.\n\n        Returns:\n            bool: True if the worker should continue running (unprocessed samples remain), \n                False if all visited samples are finished (worker can stop).\n        \"\"\"\n        finish = True\n        for node in self.node_queue:\n            if node.node_id in finished_res:\n                if len(finished_res[node.node_id]) == len(visited_agentoutputs):\n                    finish = False\n                elif len(finished_res[node.node_id]) > len(visited_agentoutputs):\n                    assert False, \"This should not happen\"\n                else:\n                    finish = True\n            else:\n                finish = True\n        return finish\n    \n    async def colocate_task(self, agent_output:AgentOutput, agent_res:Dict,  finished_res: Dict, cur_node: Node, node_idx: int, sampling_params: Dict[str, Any], global_bs: int,  timing_raw: Dict[str, float]):\n        \"\"\"\n        Asynchronous task for generate with a single `AgentOutput` in colocated mode.\n        \n        Handles end-to-end processing for one sample: environment observation fetching, prompt preprocessing,\n        model generation, response postprocessing, environment step execution (if enabled), and data propagation\n        to the next node in the DAG. Also tracks finished samples and updates result dictionaries.\n        \n        Args:\n            agent_output: `AgentOutput` object containing the sample's prompt, metadata, and status.\n            agent_res: Global dictionary to store processed `AgentOutput` results (key: node ID, value: request ID → AgentOutput).\n            finished_res: Global dictionary to track fully processed samples (key: node ID, value: set of request IDs).\n            cur_node: Current DAG node being processed (defines agent logic and environment config).\n            node_idx: Index of `cur_node` in the DAG node queue (for next node lookup).\n            sampling_params: Model sampling hyperparameters (e.g., temperature, top_p) for sequence generation.\n            global_bs: Global batch size (DP size × local batch size) for consistent data routing.\n            timing_raw: Dictionary to record raw timing metrics (e.g., preprocessing, generation, environment step latency).\n        \"\"\"\n        cur_dp_size, cur_dp_rank, *_ = self.dag._get_node_dp_info(cur_node)\n        node_worker:ActorRolloutRefWorker = self.workers[self._generate_node_worker_key(cur_node)]\n        next_node = self.node_queue[node_idx + 1] if node_idx < len(self.node_queue) - 1  else self.node_queue[0]\n        next_key = self._generate_key(cur_node, next_node, agent_output.batch_id, global_bs)\n        next_dp_size, *_ = self.dag._get_node_dp_info(next_node)\n        obs = None\n        if agent_output.status !=  AgentOutputStatus.RUNNING:\n            if cur_node.node_id not in finished_res:\n                finished_res[cur_node.node_id] = set()\n            \n            finished_res[cur_node.node_id].add(agent_output.request_id)\n            # pre agent use same rewards with last agent\n            await self.async_put_data(next_key, agent_output, cur_dp_size, next_dp_size, timing_raw)\n            return\n        if cur_node.agent_options and 
cur_node.agent_options.obs_with_env:\n            obs = await self.async_get_envdata(agent_output.request_id + f'_{cur_node.agent_group}', timing_raw)\n        agent_output.original_prompt, agent_output.templated_prompt = cur_node.agent_process.apply_pre_process(prompt=agent_output.original_prompt, obs = obs)\n        agent_output.templated_prompt = agent_output.templated_prompt[:self.rollout_config.prompt_length]\n\n        response = await node_worker.rollout.generate(\n            request_id=agent_output.request_id, prompt_ids=agent_output.templated_prompt, sampling_params=sampling_params\n            )\n        if len(response) == 0:\n            # if response is None, padding response some prompt for training\n            response = \"<|im_end|>\"\n        agent_output.original_prompt, agent_output.templated_prompt, agent_output.response_mask \\\n            = cur_node.agent_process.apply_post_process(oridinal_prompt = agent_output.original_prompt, templated_prompt = agent_output.templated_prompt, response = response)     \n                   \n        # if have env\n        if cur_node.agent_options and cur_node.agent_options.obs_with_env:\n            if cur_node.agent_process.env:    \n                pre_agent_actions = {}\n                for i in range(cur_node.agent_group):\n                    pre_agent_actions[i] = await self.async_get_envdata(agent_output.request_id + f'_{i}', timing_raw)\n            \n                for env_id, env_manager in enumerate(cur_node.agent_process.env_managers):\n                    # only support one env now\n                    if agent_output.request_id not in env_manager:\n                        env_class = cur_node.agent_process.env[env_id]\n                        env_manager[agent_output.request_id + f'{cur_node.agent_group}'] = env_class()\n                    env_instance = env_manager[agent_output.request_id + f'{cur_node.agent_group}']\n                    \n                    \n                    pre_agent_actions = [data for data in list(pre_agent_actions.values()) if data is not None]\n                    next_obs, rewards, should_stop = await env_instance.step(actions = pre_agent_actions + [agent_output.original_prompt], ground_truth = agent_output.ground_truth)\n                    \n                    agent_output.rewards = rewards\n                    agent_output.original_prompt = next_obs[-1]\n                    \n                    if should_stop:\n                        agent_output.status = AgentOutputStatus.ENV_FINISH\n                    # todo: add multienv process\n                if isinstance(next_obs, list) and (isinstance(next_obs[0], list) or isinstance(next_obs[0], str)):\n                    # have multi-agent obs\n                    assert len(next_obs) == len(self.node_queue), f\"env return {len(next_node)} obs, should equal agent num {len(self.node_queue)}\"\n                    # this data may be get in last node ,force put to databuffer temporarily\n                    for i in range(cur_node.agent_group):\n                        if next_obs[i] is None:\n                            assert False\n                        await self.async_put_data(agent_output.request_id + f'_{i}', next_obs[i], 2, 4, timing_raw)\n            else:\n                # this data will be get in last node ,force put to databuffer temporarily\n                assert isinstance(agent_output.original_prompt, str)\n                if agent_output.original_prompt is None:\n                    assert False\n                
await self.async_put_data(agent_output.request_id + f'_{cur_node.agent_group}', agent_output.original_prompt, 2, 4, timing_raw)\n                \n        input_and_response = agent_output.templated_prompt\n        agent_output.templated_prompt = input_and_response[: len(agent_output.templated_prompt) - len(agent_output.response_mask)]\n        agent_output.response_mask = agent_output.response_mask[:self.rollout_config.response_length]\n        if len(agent_output.response_mask) == 0:\n            # multi-agent may response none\n            agent_output.response_id = []\n        else:\n            agent_output.response_id = input_and_response[-len(agent_output.response_mask) :]\n        if cur_node.node_id not in agent_res:\n            agent_res[cur_node.node_id] = {}\n            agent_res[cur_node.node_id][agent_output.request_id] = []\n        if agent_output.request_id not in agent_res[cur_node.node_id]:\n            agent_res[cur_node.node_id][agent_output.request_id] = []\n        if self.rollout_config.multi_turn.use_all_traj:\n            agent_res[cur_node.node_id][agent_output.request_id].append(copy.deepcopy(agent_output))\n        else:\n            agent_res[cur_node.node_id][agent_output.request_id]=[copy.deepcopy(agent_output)]\n        # last node need to add turn\n        if node_idx == len(self.node_queue) - 1:\n            agent_output.turn = agent_output.turn + 1\n            if agent_output.turn >= self.rollout_config.multi_turn.max_assistant_turns:\n                agent_output.status = AgentOutputStatus.Turn_FINISH\n            if len(agent_output.templated_prompt) >= self.max_model_len:\n                agent_output.status = AgentOutputStatus.LENGTH_FINISH\n\n        if agent_output.status !=  AgentOutputStatus.RUNNING:\n            if cur_node.node_id not in finished_res:\n                finished_res[cur_node.node_id] = set()\n            finished_res[cur_node.node_id].add(agent_output.request_id)\n       \n        await self.async_put_data(next_key, agent_output, cur_dp_size, next_dp_size, timing_raw)\n        \n        return\n    \n    async def generate_colocate(self, bs, sampling_params: Dict[str, Any], timing_raw: Dict[str, float]):\n        \"\"\"\n        Asynchronously generates sequences in **colocated mode** (model and data reside on the same worker),\n        handling distributed coordination (Data Parallelism/DP + Tensor Parallelism/TP) across all nodes in the DAG.\n        \n        This function manages a loop to fetch input data, dispatch generation tasks, and synchronize across\n        distributed ranks until all samples in the batch are fully processed.\n        \n        Args:\n            bs: Local batch size (number of samples to process per individual worker).\n            sampling_params: Dictionary of model sampling hyperparameters (e.g., temperature, top_p, repetition_penalty).\n            timing_raw: Dictionary to record raw timing metrics (e.g., data fetch latency, task execution time) for debugging/benchmarking.\n        \n        Returns:\n            agent_res: Dictionary mapping node IDs (str) to lists of `AgentOutput` objects. 
Each `AgentOutput` contains\n                    the generated sequence, prompt metadata, and other task-related results for a single sample.\n        \"\"\"\n        agent_res: Dict[str, List[AgentOutput]] = {}\n        finished_res: Dict[str, set] = {}\n        agent_num = len(self.node_queue)  \n        global_running = True\n        local_running = True  \n        while global_running:\n            for i in range(agent_num):\n                visited_agentoutputs = set()\n                cur_node:Node = self.node_queue[i]  \n                cur_dp_size, cur_dp_rank, cur_tp_rank, *_ = self.dag._get_node_dp_info(cur_node)\n                if i == 0:\n                    global_bs = cur_dp_size * bs\n                node_worker :ActorRolloutRefWorker = self.workers[self._generate_node_worker_key(cur_node)]\n                key = f\"{cur_node.node_id}_{cur_dp_rank}\"\n                workers = []\n                await node_worker.rollout.wake_up()\n                \n                # assert self.tran_bs * self.rollout_config.n % cur_dp_size == 0, f\"global batch size is f{self.tran_bs * self.rollout_config.n} can't div by f{cur_dp_size} in node {cur_node.node_id}\"\n                \n                if cur_tp_rank == 0 and local_running:\n                    while True:\n                        agent_outputs:List[AgentOutput] = await self.async_get_data(key, timing_raw)               \n                        if agent_outputs is not None:\n                            for agent_output in agent_outputs:\n                                visited_agentoutputs.add(agent_output.request_id)\n                                worker_task = asyncio.create_task(\n                                        self.colocate_task(agent_output = agent_output, \n                                                        agent_res = agent_res, \n                                                        finished_res = finished_res,\n                                                        cur_node = cur_node, \n                                                        node_idx = i, \n                                                        sampling_params = sampling_params, \n                                                        global_bs = global_bs,\n                                                        timing_raw = timing_raw)\n                                    )\n                                workers.append(worker_task)\n                        \n                        if len(visited_agentoutputs) == bs:\n                            await asyncio.gather(*workers)\n                            break   \n\n                torch.distributed.barrier(node_worker.rollout.get_device_mesh()[\"tp\"].get_group())  \n\n\n                # Note: in async mode, can't global barrier\n                await node_worker.rollout.sleep()        \n                local_running = await self.check_colocate_running(finished_res, visited_agentoutputs)\n                # tp 0 broadcast to other tp\n                tp_group = self.tail_device_mesh[\"tp\"].get_group()\n                tp_local_rank = self.tail_device_mesh[\"tp\"].get_local_rank()\n                src_local_rank = 0\n                src_global_rank = self.tail_device_mesh[\"tp\"].mesh.tolist()[src_local_rank]\n                broadcast_list = [None]\n                if tp_local_rank == src_local_rank:\n                    broadcast_list[0] = local_running\n\n                dist.broadcast_object_list(\n                    object_list=broadcast_list,  \n                    
src=src_global_rank,    \n                    group=tp_group          \n                )\n                local_running = broadcast_list[0]\n                        \n                        \n                finish_flag_tensor = torch.tensor(0 if local_running else 1, device=\"cuda\" if torch.cuda.is_available() else \"cpu\")\n                dist.all_reduce(finish_flag_tensor, op=dist.ReduceOp.SUM)\n                total_finish = finish_flag_tensor.item()\n                if total_finish == dist.get_world_size():\n                    global_running = False\n                else:\n                    global_running = True\n                \n        return agent_res   \n    def _postprocess(self, agent_outputs: Dict[str, List[AgentOutput]], metrics: Dict) -> TensorDict:\n        \"\"\"\n        Postprocesses generated agent outputs into a structured TensorDict object.\n        \n        Combines prompts and responses into formatted tensors (input_ids, attention_mask, etc.)\n        with proper padding and metadata, handling multiple nodes in the DAG.\n        \n        Args:\n            agent_outputs: Dictionary mapping node IDs to lists of AgentOutput objects containing\n                        generated responses, prompts, and metadata for each sample.\n        \n        Returns:\n            TensorDict object containing concatenated batch data (tensors) and metadata from all nodes.\n        \"\"\"\n        # NOTE: Consistent with batch version of generate_sequences in vllm_rollout_spmd.py\n        # - prompts: Left-padded to fixed length\n        # - responses: Right-padded to fixed length\n        # - input_ids: Concatenation of prompt and response token IDs\n        # - attention_mask: Combines prompt mask (left) and response mask (right)\n        #   Format: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]\n        # - position_ids: Sequential numbering for valid tokens (masked tokens get 0)\n        #   Format: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]\n        async def _single_postprocess(agent_outputs: List[AgentOutput], node: Node, metrics: Dict):\n            \"\"\"\n            Helper function to postprocess outputs for a single node in the DAG.\n            \n            Args:\n                agent_outputs: List of AgentOutput objects for the current node.\n                node: Current DAG node (contains agent processor and tokenizer).\n            \n            Returns:\n                TensorDict with processed tensors for the current node.\n            \"\"\"\n            # Sort agent outputs by batch ID to maintain original order\n            cur_agent_outputs = list(agent_outputs.values())\n            cur_agent_outputs = [\n                list(agent_output)  \n                for agent_output in sorted(cur_agent_outputs, key=lambda x: x[0].batch_id)\n            ]\n            prompt_texts = [step_output.templated_prompt for agent_output in cur_agent_outputs for step_output in agent_output]\n            prompt_texts = node.agent_process.tokenizer.batch_decode(prompt_texts, skip_special_tokens=True)\n            node.agent_process.tokenizer.padding_side = \"left\"\n            input_ids = [{\"input_ids\": step_output.templated_prompt} for agent_output in cur_agent_outputs for step_output in agent_output]\n            batch_size = len(input_ids)\n            world_size = dist.get_world_size()\n            pad_batch_size = 0\n            if remainder := batch_size % world_size:\n                # need pad    \n                pad_batch_size = world_size - remainder\n                
for _ in range(pad_batch_size):\n                    input_ids.append(input_ids[0].copy())\n            outputs = node.agent_process.tokenizer.pad(\n                input_ids,\n                padding=\"max_length\",\n                max_length=self.rollout_config.prompt_length,\n                return_tensors=\"pt\",\n                return_attention_mask=True,\n            )\n            prompt_ids, prompt_attention_mask = outputs[\"input_ids\"], outputs[\"attention_mask\"]\n            # responses\n            node.agent_process.tokenizer.padding_side = \"right\"\n            response_ids = [{\"input_ids\": step_output.response_id} for agent_output in cur_agent_outputs for step_output in agent_output]\n            if pad_batch_size:\n                for _ in range(pad_batch_size):\n                    response_ids.append(response_ids[0].copy())\n            outputs = node.agent_process.tokenizer.pad(\n                response_ids,\n                padding=\"max_length\",\n                max_length=self.rollout_config.response_length,\n                return_tensors=\"pt\",\n                return_attention_mask=True,\n            )\n            response_ids, response_attention_mask = outputs[\"input_ids\"], outputs[\"attention_mask\"]\n\n            # response_mask\n            response_masks = [{\"input_ids\": step_output.response_mask} for agent_output in cur_agent_outputs for step_output in agent_output]\n            \n            if pad_batch_size:\n                for _ in range(pad_batch_size):\n                    response_masks.append({\"input_ids\":[0] * len(response_masks[0][\"input_ids\"])})\n            outputs = node.agent_process.tokenizer.pad(\n                response_masks,\n                padding=\"max_length\",\n                max_length=self.rollout_config.response_length,\n                return_tensors=\"pt\",\n                return_attention_mask=False,\n            )\n            response_mask = outputs[\"input_ids\"]\n            assert response_ids.shape == response_mask.shape, (\n                f\"mismatch in response_ids and response_mask shape: {response_ids.shape} vs {response_mask.shape}\"\n            )\n            response_mask = response_mask * response_attention_mask\n            request_ids = []\n            traj_len = []\n            traj_step = []\n            for agent_output in cur_agent_outputs:\n                traj = len(agent_output)\n                step = 0\n                for step_output in agent_output:\n                    request_ids.append(step_output.request_id)\n                    traj_len.append(traj)\n                    traj_step.append(step)\n                    if self.rollout_config.agent.rewards_with_env:\n                        metrics[f\"agent_{node.agent_group}_critic/step_{step}_rewards/mean\"].append(step_output.rewards)\n                    step = step + 1\n            if pad_batch_size:\n                for _ in range(pad_batch_size):\n                    request_ids.append(\"pad_request\")\n                    traj_len.append(1)\n                    traj_step.append(0)\n                    prompt_texts.append(prompt_texts[0])\n            input_ids = torch.cat([prompt_ids, response_ids], dim=1)\n            attention_mask = torch.cat([prompt_attention_mask, response_attention_mask], dim=1)\n            position_ids = (attention_mask.cumsum(dim=1) - 1) * attention_mask\n            prefix = f\"agent_group_{node.agent_group}_\"\n            batch = TensorDict(\n                { \n                    prefix + 
\"prompts\": prompt_ids,  # [bsz, prompt_length]\n                    prefix + \"responses\": response_ids,  # [bsz, response_length]\n                    prefix + \"response_mask\": response_mask,  # [bsz, response_length]\n                    prefix + \"input_ids\": input_ids,  # [bsz, prompt_length + response_length]\n                    prefix + \"attention_mask\": attention_mask,  # [bsz, prompt_length + response_length]\n                    prefix + \"position_ids\": position_ids,  # [bsz, prompt_length + response_length]\n                    \n                },\n                batch_size=len(input_ids),\n            )\n            non_tensor_batch = {\n                prefix + \"request_id\": np.array(request_ids),\n                prefix + \"traj_len\": np.array(traj_len),\n                prefix + \"traj_step\": np.array(traj_step),\n                prefix + \"prompt_texts\": np.array(prompt_texts)\n            }\n            if node.node_id == self.node_queue[-1].node_id and self.rollout_config.agent.rewards_with_env:\n                reward_tensor = torch.zeros_like(batch[prefix + \"responses\"], dtype=torch.float32)\n                idx = 0\n                for agent_outputs in cur_agent_outputs:\n                    for step_output in agent_outputs:\n                        prompt_ids = batch[prefix + \"prompts\"][idx]\n                        prompt_length = prompt_ids.shape[-1]\n                        valid_response_length = batch[prefix + \"attention_mask\"][idx][prompt_length:].sum()\n                        reward_tensor[idx, valid_response_length - 1] = step_output.rewards\n                        idx = idx + 1\n                batch[prefix + \"token_level_rewards\"] = reward_tensor\n                batch[prefix + \"token_level_scores\"] = copy.deepcopy(reward_tensor)\n                \n            if node.agent_process.env:\n                node.agent_process.env_managers.clear()\n                node.agent_process.env_managers = [{}]\n            for step_id in range(self.rollout_config.multi_turn.max_assistant_turns):\n                step_key = f'agent_{node.agent_group}_critic/step_{step_id}_rewards/mean'\n                metrics[f\"agent_{node.agent_group}_critic/step_{step_id}_rewards/mean\"] = np.mean(metrics[f\"agent_{node.agent_group}_critic/step_{step_id}_rewards/mean\"]) / (step_id + 1)\n            return TensorDict(batch=batch, non_tensor_batch=non_tensor_batch, meta_info={\"metrics\": {}})\n        \n        tasks = []\n        for i in range(len(self.node_queue)):\n            cur_node = self.node_queue[i]\n            _, cur_dp_rank, cur_tp_rank, *_ = self.dag._get_node_dp_info(cur_node)\n            for len_id in range(self.rollout_config.multi_turn.max_assistant_turns):\n                metrics[f\"agent_{cur_node.agent_group}_critic/step_{len_id}_rewards/mean\"] = []\n            if cur_tp_rank == 0:\n                tasks.append(_single_postprocess(agent_outputs = agent_outputs[cur_node.node_id], node = cur_node, metrics = metrics))\n        loop = asyncio.get_event_loop()\n        datas = loop.run_until_complete(asyncio.gather(*tasks))    \n        dataproto = None\n        for data in datas:\n            if dataproto:\n                dataproto.union(data)\n            else:\n                dataproto = data\n        return dataproto\n\n    def generate_sequence(self, batch:TensorDict, timing_raw: Dict[str, float] = {}):\n        \"\"\"\n        Generate model output sequences based on the input TensorDict batch, handling both colocated and 
spread placement modes.\n        Args:\n            batch: Input TensorDict object containing raw data (e.g., prompts, ground truth) for sequence generation.\n            timing_raw: Dictionary to record raw timing metrics (e.g., data transfer, generation latency). Defaults to empty dict.\n        Returns:\n            Processed TensorDict object with generated sequences, metadata (including metrics), and prefix removed for downstream DAG worker.\n        \"\"\"\n        prompts = None\n        metrics = {}\n        loop = asyncio.get_event_loop()\n        entry_node = self.node_queue[0]\n        sampling_params = dict(\n            temperature=self.rollout_config.temperature,\n            top_p=self.rollout_config.top_p,\n            repetition_penalty=1.0,\n        )\n        # override sampling params for validation\n        if batch.meta_info.get(\"validate\", False):\n            sampling_params[\"top_p\"] = self.rollout_config.val_kwargs.top_p\n            sampling_params[\"temperature\"] = self.rollout_config.val_kwargs.temperature\n            \n        if self.placement_mode == 'colocate' or self.node_if_local(entry_node):\n            prompts, ground_truth = self._preprocess(batch)\n            prompts_ids = entry_node.agent_process.tokenizer.apply_chat_template(\n                        prompts,\n                        add_generation_prompt=True,\n                        tokenize=True,\n                    )\n            tasks = []\n            dp_size, dp_rank, tp_rank, *_ = self.dag._get_node_dp_info(entry_node)\n            if tp_rank == 0:\n                for i in range(len(prompts_ids)):\n                    key = self._generate_key(entry_node, entry_node, i)\n                    tasks.append(self.async_put_data(key, \n                        AgentOutput(batch_id = i, \n                                    original_prompt = prompts_ids[i], \n                                    templated_prompt = '', \n                                    should_stop = False, \n                                    response_mask = [0] * len(prompts_ids[i]), \n                                response_id = [], \n                                    request_id = uuid4().hex,\n                                    ground_truth = ground_truth[i]), \n                        dp_size, dp_size, timing_raw))\n                loop.run_until_complete(asyncio.gather(*tasks))      \n        with Timer(name=\"generate_sequences\", logger=None) as timer:\n            if self.placement_mode == 'spread':\n                # if in different GPUWorker\n                response,response_mask = loop.run_until_complete(self.generate_spread(timing_raw))\n            elif self.placement_mode == 'colocate':\n                # if in same GPUWorker\n                agent_outputs = loop.run_until_complete(self.generate_colocate(len(prompts_ids), sampling_params, timing_raw))\n        delta_time = timer.last\n        metrics[\"perf/delta_time/multi_agent_generate\"] = delta_time\n        generated_proto = self._postprocess(agent_outputs, metrics) \n          \n        # remove last node prefix, because it will be add in dagworker\n        if generated_proto:\n            generated_proto.meta_info.update({\"metrics\": metrics})\n            generated_proto = remove_prefix_from_dataproto(generated_proto, self.node_queue[-1])\n        # databuffer will reset in dagworker, so only reset internal_dict\n        # but in validate step, databuffer will not clean\n        if batch.meta_info.get(\"validate\", False):\n            
dist.barrier()\n            if dist.get_rank() == 0:\n                tasks = [databuffer.reset.remote() for databuffer in self.data_buffers]\n                ray.get(tasks)\n        self.internal_data_cache.clear()\n        return generated_proto\n    \n\n \n"
  },
  {
    "path": "siirl/execution/rollout_flow/multi_agent/utils.py",
    "content": "from pydantic import BaseModel\nfrom typing import List, Optional, Union, Any\n\nclass AgentOutputStatus:\n    RUNNING = 0\n    LENGTH_FINISH = 1\n    ENV_FINISH = 2\n    Turn_FINISH = 3\n\nclass AgentOutput(BaseModel):\n    batch_id: int = -1\n    original_prompt: Optional[Union[str, List[int]]]\n    response_id: Optional[Union[str, List[int]]]\n    templated_prompt: Optional[Union[str, List[int]]]\n    should_stop: bool = False\n    response_mask: Optional[List[int]]\n    env_obs: Optional[Union[str, List[int]]] = \"\"\n    ground_truth: Any = ''\n    rewards: int = 0\n    status: str = AgentOutputStatus.RUNNING\n    turn: int = 0\n    request_id: str = \"None\"\n\n\n\n     "
  },
  {
    "path": "siirl/execution/rollout_flow/multiturn/__init__.py",
    "content": ""
  },
  {
    "path": "siirl/execution/rollout_flow/multiturn/agent_loop/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .agent_loop import AgentLoopBase, AgentLoopManager\n\n__all__ = [\"AgentLoopBase\", \"AgentLoopManager\"]\n"
  },
  {
    "path": "siirl/execution/rollout_flow/multiturn/agent_loop/agent_loop.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport asyncio\nimport heapq\nimport logging\nimport os\nimport random\nfrom abc import ABC, abstractmethod\nfrom typing import Any, Dict, List, Type\n\nimport numpy as np\nimport ray\nimport torch\nfrom cachetools import LRUCache\nfrom omegaconf import DictConfig\nfrom pydantic import BaseModel\nfrom tensordict import TensorDict\n# Handle different tensordict versions - NonTensorData location varies\ntry:\n    from tensordict import NonTensorData\nexcept ImportError:\n    from tensordict.tensorclass import NonTensorData\nfrom transformers import AutoTokenizer\n\nfrom siirl.models.loader import load_tokenizer\nfrom siirl.utils.extras.fs import copy_to_local\nfrom siirl.engine.rollout.async_server import async_server_class\nfrom loguru import logger\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"siirl_LOGGING_LEVEL\", \"WARN\"))\n\nasync def get_ddp_world_size_rank(local_world_size, local_rank, local_parallel_size):\n    \n    ddp_world_size = local_world_size // local_parallel_size\n    ddp_rank = local_rank // local_parallel_size\n    return ddp_world_size, ddp_rank\n\n\n\nclass AsyncLLMServerManager:\n    \"\"\"\n    A class to manage multiple OpenAI compatible LLM servers. This class provides\n    - Load balance: least requests load balancing\n    - Sticky session: send multi-turn chat completions to same server for automatic prefix caching\n    \"\"\"\n\n    def __init__(self, config: DictConfig, server, max_cache_size: int = 10000):\n        \"\"\"Initialize the AsyncLLMServerManager.\n\n        Args:\n            config (DictConfig): YAML config.\n            server (vllm/sglang async engine): OpenAI compatible LLM server.\n            max_cache_size (int, optional): max cache size for request_id to server mapping. 
Defaults to 10000.\n        \"\"\"\n        self.config = config\n        self.server = server\n\n    async def generate(\n        self,\n        request_id,\n        *,\n        prompt_ids: List[int],\n        sampling_params: Dict[str, Any],\n    ) -> List[int]:\n        \"\"\"Generate tokens from prompt ids.\n\n        Args:\n            request_id (str): request id for sticky session.\n            prompt_ids (List[int]): List of prompt token ids.\n            sampling_params (Dict[str, Any]): Sampling parameters for the chat completion.\n\n        Returns:\n            List[int]: List of generated token ids.\n        \"\"\"\n\n        output = await self.server.generate(\n            request_id=request_id,\n            prompt_ids=prompt_ids,\n            sampling_params=sampling_params,\n        )\n        return output\n\n\nclass AgentLoopMetrics(BaseModel):\n    \"\"\"Agent loop performance metrics.\"\"\"\n\n    generate_sequences: float = 0.0\n    tool_calls: float = 0.0\n\n\nclass AgentLoopOutput(BaseModel):\n    \"\"\"Agent loop output.\"\"\"\n\n    prompt_ids: List[int]\n    response_ids: List[int]\n    response_mask: List[int]\n    num_turns: int = 0\n    metrics: AgentLoopMetrics\n\n\nclass AgentLoopBase(ABC):\n    \"\"\"An agent loop takes a input message, chat with OpenAI compatible LLM server and interact with various\n    environments.\"\"\"\n\n    _class_initialized = False\n\n    def __init__(self, config: DictConfig, server_manager: AsyncLLMServerManager, tokenizer: AutoTokenizer):\n        \"\"\"Initialize agent loop.\n\n        Args:\n            config (DictConfig): YAML config.\n            server_manager (AsyncLLMServerManager): OpenAI compatible LLM server manager.\n            tokenizer (AutoTokenizer): Tokenizer for tokenize messages.\n        \"\"\"\n        self.config = config\n        self.server_manager = server_manager\n        self.tokenizer = tokenizer\n        self.loop = asyncio.get_running_loop()\n        self.init_class(config, tokenizer)\n\n    @classmethod\n    def init_class(cls, config: DictConfig, tokenizer: AutoTokenizer):\n        \"\"\"Initialize class state shared across all instances.\"\"\"\n        if cls._class_initialized:\n            return\n        cls._class_initialized = True\n\n    @abstractmethod\n    async def run(self, messages: List[Dict[str, Any]], sampling_params: Dict[str, Any]) -> AgentLoopOutput:\n        \"\"\"Run agent loop to interact with LLM server and environment.\n\n        Args:\n            messages (List[Dict[str, Any]]): Input messages.\n            sampling_params (Dict[str, Any]): LLM sampling params.\n\n        Returns:\n            AgentLoopOutput: Agent loop output.\n        \"\"\"\n        raise NotImplementedError\n\n\nclass AgentLoopWorker:\n    \"\"\"Agent loop worker takes a batch of messages and run each message in an agent loop.\"\"\"\n\n    def __init__(self, config: DictConfig, server_handles: List[ray.actor.ActorHandle]):\n        \"\"\"Initialize agent loop manager.\n\n        Args:\n            config (DictConfig): YAML config.\n            server_handles (List[ray.actor.ActorHandle]): OpenAI compatible LLM server actor handles.\n        \"\"\"\n        self.config = config\n        self.server_manager = AsyncLLMServerManager(config, server_handles)\n        model_path = config.model.path\n        self.model_name = \"/\".join(model_path.split(\"/\")[-2:])\n        local_path = copy_to_local(config.model.path)\n        tokenizer_module = load_tokenizer(model_args=self.config.model)\n        
self.tokenizer, self.processor = tokenizer_module[\"tokenizer\"], tokenizer_module[\"processor\"]\n\n    async def generate_sequences(self, batch: TensorDict) -> TensorDict:\n        \"\"\"Generate sequences from agent loop.\n\n        Args:\n            batch (TensorDict): Input batch.\n\n        Returns:\n            TensorDict: Output batch.\n            - prompts: [bsz, prompt_length], prompt token ids from dataset.\n            - responses: [bsz, response_length], output token ids include response tokens\n              from LLM generation and observation tokens from tool_calls.\n            - response_mask: [bsz, response_length], 1 for LLM generated tokens, 0 for observation/padding tokens.\n            - input_ids: [bsz, prompt_length + response_length], whole sequence token ids, including prompt tokens\n              and response tokens.\n            - attention_mask: [bsz, prompt_length + response_length], 0 for padding tokens, 1 for other tokens.\n            - position_ids: [bsz, prompt_length + response_length], incremental position ids.\n\n            For multi-turn conversations:\n            responses:     |<- LLM generation ->|<- tool_calls ->|<- LLM generation ->|<- padding ->|\n            response_mask: | 1, 1, 1, ..., 1, 1 | 0, 0, .., 0, 0 | 1, 1, 1, ..., 1, 1 | 0, 0, ..., 0|\n        \"\"\"\n        config = self.config.rollout\n        sampling_params = dict(\n            temperature=config.temperature,\n            top_p=config.top_p,\n            repetition_penalty=1.0,\n        )\n\n        # override sampling params for validation\n        if getattr(batch.get(\"validate\"), 'data', False):\n            sampling_params[\"top_p\"] = config.val_kwargs.top_p\n            sampling_params[\"temperature\"] = config.val_kwargs.temperature\n\n        tasks = []\n        # by default, we assume it's a single turn agent\n        agent_name = self.config.rollout.agent.agent_name\n        \n        # agent_names = batch.non_tensor_batch[\"agent_name\"].repeat(n, axis=0)\n        raw_prompts = batch[\"raw_prompt\"]\n        target_size = raw_prompts.shape[0]\n        agent_names = np.full(target_size, agent_name)\n        \n        for agent_name, messages in zip(agent_names, raw_prompts):\n            tasks.append(asyncio.create_task(self._run_agent_loop(agent_name, messages.tolist(), sampling_params)))\n        outputs = await asyncio.gather(*tasks)\n        output = self._postprocess(outputs)\n        batch.update(output)\n        return batch\n\n    async def _run_agent_loop(\n        self, agent_name: str, messages: List[Dict[str, Any]], sampling_params: Dict[str, Any]\n    ) -> AgentLoopOutput:\n        agent_loop_class = self.get_agent_loop_class(agent_name)\n        agent_loop = agent_loop_class(self.config, self.server_manager, self.tokenizer)\n        output = await agent_loop.run(messages, sampling_params)\n        return output\n\n    def get_agent_loop_class(self, agent_name: str) -> Type[AgentLoopBase]:\n        # TODO: add tool agent registrary\n        from siirl.execution.rollout_flow.multiturn.agent_loop.single_turn_agent_loop import SingleTurnAgentLoop\n        from siirl.execution.rollout_flow.multiturn.agent_loop.tool_agent_loop import ToolAgentLoop\n\n        if agent_name == \"single_turn_agent\":\n            return SingleTurnAgentLoop\n        elif agent_name == \"tool_agent\":\n            return ToolAgentLoop\n        raise ValueError(f\"Unknown agent_name: {agent_name}\")\n\n    def _postprocess(self, inputs: List[AgentLoopOutput]) -> TensorDict:\n  
      # NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py\n        # prompts: left pad\n        # responses: right pad\n        # input_ids: prompt + response\n        # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]\n        # position_ids:   [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]\n\n        # prompts\n        self.tokenizer.padding_side = \"left\"\n        outputs = self.tokenizer.pad(\n            [{\"input_ids\": input.prompt_ids} for input in inputs],\n            padding=\"max_length\",\n            max_length=self.config.rollout.prompt_length,\n            return_tensors=\"pt\",\n            return_attention_mask=True,\n        )\n        prompt_ids, prompt_attention_mask = outputs[\"input_ids\"], outputs[\"attention_mask\"]\n\n        # responses\n        self.tokenizer.padding_side = \"right\"\n        outputs = self.tokenizer.pad(\n            [{\"input_ids\": input.response_ids} for input in inputs],\n            padding=\"max_length\",\n            max_length=self.config.rollout.response_length,\n            return_tensors=\"pt\",\n            return_attention_mask=True,\n        )\n        response_ids, response_attention_mask = outputs[\"input_ids\"], outputs[\"attention_mask\"]\n\n        # response_mask\n        outputs = self.tokenizer.pad(\n            [{\"input_ids\": input.response_mask} for input in inputs],\n            padding=\"max_length\",\n            max_length=self.config.rollout.response_length,\n            return_tensors=\"pt\",\n            return_attention_mask=False,\n        )\n        response_mask = outputs[\"input_ids\"]\n        assert response_ids.shape == response_mask.shape, (\n            f\"mismatch in response_ids and response_mask shape: {response_ids.shape} vs {response_mask.shape}\"\n        )\n        response_mask = response_mask * response_attention_mask\n\n        input_ids = torch.cat([prompt_ids, response_ids], dim=1)\n        attention_mask = torch.cat([prompt_attention_mask, response_attention_mask], dim=1)\n        position_ids = (attention_mask.cumsum(dim=1) - 1) * attention_mask\n\n        batch = TensorDict(\n            {\n                \"prompts\": prompt_ids,  # [bsz, prompt_length]\n                \"responses\": response_ids,  # [bsz, response_length]\n                \"response_mask\": response_mask,  # [bsz, response_length]\n                \"input_ids\": input_ids,  # [bsz, prompt_length + response_length]\n                \"attention_mask\": attention_mask,  # [bsz, prompt_length + response_length]\n                \"position_ids\": position_ids,  # [bsz, prompt_length + response_length]\n            },\n            batch_size=len(input_ids),\n        )\n\n        num_turns = np.array([input.num_turns for input in inputs], dtype=np.int32)\n        metrics = [input.metrics.model_dump() for input in inputs]\n        batch[\"__num_turns__\"] = NonTensorData(num_turns)\n        batch[\"metrics\"] = NonTensorData(metrics)\n        return batch\n\nclass AgentLoopManager:\n    \"\"\"Agent loop manager that manages a group of agent loop workers.\"\"\"\n\n    def __init__(self, config: DictConfig, cur_dp_rank, name_prefix, engine, zmq_addresses:List):\n        \"\"\"Initialize agent loop manager.\n\n        Args:\n            config (DictConfig): trainer config.\n            worker_group (RayWorkerGroup): ActorRolloutRef worker group.\n        \"\"\"\n        self.config = config\n        self.cur_dp_rank = cur_dp_rank\n        self.name_prefix = name_prefix\n        self.engine = engine\n 
       self.zmq_addresses = zmq_addresses\n        self._initialize_llm_servers()\n        self._init_agent_loop_workers()\n\n    def _initialize_llm_servers(self):\n        self.rollout_tp_size = self.config.rollout.tensor_model_parallel_size\n        # self.rollout_dp_size = self.worker_group.world_size // self.rollout_tp_size\n        # in siirl, every dp has private rollout engine\n        self.rollout_dp_size = 1\n        if self.config.rollout.agent.custom_async_server:\n            server_class = async_server_class(\n                rollout_backend=self.config.rollout.name,\n                rollout_backend_module=self.config.rollout.agent.custom_async_server.path,\n                rollout_backend_class=self.config.rollout.agent.custom_async_server.name,\n            )\n        else:\n            server_class = async_server_class(rollout_backend=self.config.rollout.name)\n\n        self.async_llm_server = server_class(self.config, self.engine, self.zmq_addresses)\n        self.async_llm_server.init_engine()\n    def _init_agent_loop_workers(self):\n        self.agent_loop_worker = AgentLoopWorker(self.config, self.async_llm_server)\n            \n\n    async def generate_sequences(self, prompts: TensorDict) -> TensorDict:\n        \"\"\"Split input batch and dispatch to agent loop workers.\n\n        Args:\n            prompts (TensorDict): Input batch.\n\n        Returns:\n            TensorDict: Output batch.\n        \"\"\"\n        if self.config.rollout.free_cache_engine:\n            await self.wake_up()\n        output = await self.agent_loop_worker.generate_sequences(prompts)\n        if self.config.rollout.free_cache_engine:\n            await self.sleep()\n\n        # calculate performance metrics\n        metrics = [output[\"metrics\"]]  # List[List[Dict[str, str]]]\n        timing = self._performance_metrics(metrics, output)\n    \n        output[\"metrics\"] = NonTensorData(timing)\n        return output\n\n    def _performance_metrics(self, metrics: List[List[Dict[str, str]]], output: TensorDict) -> Dict[str, float]:\n        timing = {}\n        t_generate_sequences = np.array([metric[\"generate_sequences\"] for chunk in metrics for metric in chunk])\n        t_tool_calls = np.array([metric[\"tool_calls\"] for chunk in metrics for metric in chunk])\n        timing[\"agent_loop/generate_sequences/min\"] = t_generate_sequences.min()\n        timing[\"agent_loop/generate_sequences/max\"] = t_generate_sequences.max()\n        timing[\"agent_loop/generate_sequences/mean\"] = t_generate_sequences.mean()\n        timing[\"agent_loop/tool_calls/min\"] = t_tool_calls.min()\n        timing[\"agent_loop/tool_calls/max\"] = t_tool_calls.max()\n        timing[\"agent_loop/tool_calls/mean\"] = t_tool_calls.mean()\n\n        # batch sequence generation is bounded by the slowest sample\n        slowest = np.argmax(t_generate_sequences + t_tool_calls)\n        attention_mask = output[\"attention_mask\"][slowest]\n        prompt_length = output[\"prompts\"].shape[1]\n        timing[\"agent_loop/slowest/generate_sequences\"] = t_generate_sequences[slowest]\n        timing[\"agent_loop/slowest/tool_calls\"] = t_tool_calls[slowest]\n        timing[\"agent_loop/slowest/prompt_length\"] = attention_mask[:prompt_length].sum().item()\n        timing[\"agent_loop/slowest/response_length\"] = attention_mask[prompt_length:].sum().item()\n\n        return timing\n\n    async def wake_up(self):\n        \"\"\"Wake up all rollout server instances.\"\"\"\n        self.async_llm_server.wake_up()\n\n   
 async def sleep(self):\n        \"\"\"Sleep all rollout server instances.\"\"\"\n        self.async_llm_server.sleep()\n\n"
  },
  {
    "path": "siirl/execution/rollout_flow/multiturn/agent_loop/single_turn_agent_loop.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport logging\nimport os\nfrom typing import Any, Dict, List\nfrom uuid import uuid4\n\nfrom siirl.execution.rollout_flow.multiturn.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput\nfrom contextlib import contextmanager\nimport time\nimport torch\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"SIIRL_LOGGING_LEVEL\", \"WARN\"))\n\n@contextmanager\ndef _timer(name: str, timing_dict: dict):\n    \"\"\"A context manager to measure execution time of a code block.\"\"\"\n    # if self.enable_perf:\n        # torch.cuda.synchronize()\n    start_time = time.perf_counter()\n    yield\n    # if self.enable_perf:\n    #     torch.cuda.synchronize()\n    end_time = time.perf_counter()\n    timing_dict[name] = timing_dict.get(name, 0)  + end_time - start_time\n\n\nclass SingleTurnAgentLoop(AgentLoopBase):\n    \"\"\"Naive agent loop that only do single turn chat completion.\"\"\"\n\n    def __init__(self, config, server_manager, tokenizer):\n        super().__init__(config, server_manager, tokenizer)\n        self.prompt_length = config.rollout.prompt_length\n        self.response_length = config.rollout.response_length\n\n    async def run(self, messages: List[Dict[str, Any]], sampling_params: Dict[str, Any]) -> AgentLoopOutput:\n        metrics = {}\n        request_id = uuid4().hex\n        prompt_ids = await self.loop.run_in_executor(\n            None, lambda: self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)\n        )\n\n        with _timer(\"generate_sequences\", metrics):\n            response_ids = await self.server_manager.generate(\n                request_id=request_id, prompt_ids=prompt_ids, sampling_params=sampling_params\n            )\n        response_mask = [1] * len(response_ids)\n\n        output = AgentLoopOutput(\n            prompt_ids=prompt_ids,\n            response_ids=response_ids[: self.response_length],\n            response_mask=response_mask[: self.response_length],\n            num_turns=2,\n            metrics=metrics,\n        )\n        return output\n"
  },
  {
    "path": "siirl/execution/rollout_flow/multiturn/agent_loop/tool_agent_loop.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport asyncio\nimport json\nimport logging\nimport os\nfrom abc import ABC, abstractmethod\nfrom typing import Any, Dict, List\nfrom uuid import uuid4\n\nimport regex as re\nfrom pydantic import BaseModel\n\nfrom siirl.execution.rollout_flow.multiturn.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput\nfrom siirl.execution.rollout_flow.multiturn.tools.utils.tool_registry import initialize_tools_from_config\nfrom contextlib import contextmanager\nimport time\nimport torch\n\nlogger = logging.getLogger(__file__)\nlogger.setLevel(os.getenv(\"SIIRL_LOGGING_LEVEL\", \"WARN\"))\n@contextmanager\ndef _timer(name: str, timing_dict: dict):\n    \"\"\"A context manager to measure execution time of a code block.\"\"\"\n    # if self.enable_perf:\n    #     torch.cuda.synchronize()\n    start_time = time.perf_counter()\n    yield\n    # if self.enable_perf:\n        # torch.cuda.synchronize()\n    end_time = time.perf_counter()\n    timing_dict[name] = timing_dict.get(name, 0)  + end_time - start_time\n    \n\nclass FunctionCall(BaseModel):\n    arguments: str\n    \"\"\"\n    The arguments to call the function with, as generated by the model in JSON\n    format. Note that the model does not always generate valid JSON, and may\n    hallucinate parameters not defined by your function schema. 
Validate the\n    arguments in your code before calling your function.\n    \"\"\"\n\n    name: str\n    \"\"\"The name of the function to call.\"\"\"\n\n\nclass ToolParser(ABC):\n    @abstractmethod\n    async def extract_tool_calls(self, responses_ids: List[int], prompt_ids) -> List[FunctionCall]:\n        \"\"\"Extract tool calls from the responses.\n\n        Args:\n            responses_ids (List[int]): The ids of the responses.\n\n        Returns:\n            List[FunctionCall]: The extracted tool calls.\n        \"\"\"\n        raise NotImplementedError\n\n\nclass HermesToolParser(ToolParser):\n    \"\"\"Adapted from https://github.com/vllm-project/vllm/blob/v0.9.1/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py\"\"\"\n\n    def __init__(self, tokenizer) -> None:\n        self.tokenizer = tokenizer\n\n        self.tool_call_start_token: str = \"<tool_call>\"\n        self.tool_call_end_token: str = \"</tool_call>\"\n        self.tool_call_regex = re.compile(r\"<tool_call>(.*?)</tool_call>\", re.DOTALL)\n        # list(re.finditer(r\"<tool_call>(.*?)</tool_call>\", text, re.DOTALL))\n    async def extract_tool_calls(self, responses_ids: List[int], prompt_ids) -> List[FunctionCall]:\n        loop = asyncio.get_running_loop()\n        text = await loop.run_in_executor(None, self.tokenizer.decode, responses_ids)\n        if self.tool_call_start_token not in text or self.tool_call_end_token not in text:\n            return []\n       \n        matches = self.tool_call_regex.findall(text)\n        function_calls = []\n        for match in matches:\n            try:\n                function_call = json.loads(match)\n                name, arguments = function_call[\"name\"], function_call[\"arguments\"]\n                function_calls.append(FunctionCall(name=name, arguments=json.dumps(arguments, ensure_ascii=False)))\n            except Exception as e:\n                logger.error(f\"Failed to decode tool call: {e}\")\n        return function_calls\n\n\nclass ToolAgentLoop(AgentLoopBase):\n    def __init__(self, config, server_manager, tokenizer):\n        super().__init__(config, server_manager, tokenizer)\n\n    @classmethod\n    def init_class(cls, config, tokenizer):\n        if cls._class_initialized:\n            return\n        cls._class_initialized = True\n        print(\"Performing class-level ToolAgentLoop initialization\")\n\n        # Initialize tools from config file\n        cls.tokenizer = tokenizer\n        cls.max_user_turns = config.rollout.multi_turn.max_user_turns\n        cls.max_assistant_turns = config.rollout.multi_turn.max_assistant_turns\n        cls.max_parallel_calls = config.rollout.multi_turn.max_parallel_calls\n        cls.max_tool_response_length = config.rollout.multi_turn.max_tool_response_length\n        cls.tool_response_truncate_side = config.rollout.multi_turn.tool_response_truncate_side\n        tool_config_path = config.rollout.multi_turn.tool_config_path\n        tool_list = initialize_tools_from_config(tool_config_path) if tool_config_path else []\n        cls.tools = {tool.name: tool for tool in tool_list}\n        cls.tool_schemas = [tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True) for tool in tool_list]\n        cls.tool_parser = cls.get_tool_parser(config.rollout.multi_turn.format)\n        print(f\"Initialized tools: {cls.tools}\")\n\n        cls.prompt_length = config.rollout.prompt_length\n        cls.response_length = config.rollout.response_length\n        cls.system_prompt = 
tokenizer.apply_chat_template([{}], add_generation_prompt=False, tokenize=True)\n\n    async def run(self, messages: List[Dict[str, Any]], sampling_params: Dict[str, Any]) -> AgentLoopOutput:\n        metrics = {}\n        request_id = uuid4().hex\n        prompt_ids = await self.loop.run_in_executor(\n            None,\n            lambda: self.tokenizer.apply_chat_template(\n                messages, tools=self.tool_schemas, add_generation_prompt=True, tokenize=True\n            ),\n        )\n        response_mask = []\n\n        user_turns, assistant_turns = 0, 0\n        while True:\n            with _timer(\"generate_sequences\", metrics):\n                response_ids = await self.server_manager.generate(\n                    request_id=request_id, prompt_ids=prompt_ids, sampling_params=sampling_params\n                )\n            prompt_ids += response_ids\n            \n            response_mask += [1] * len(response_ids)\n            assistant_turns += 1\n            # reach max response length\n            if len(response_mask) >= self.response_length:\n                break\n\n            # reach max assistant turns\n            if self.max_assistant_turns and assistant_turns >= self.max_assistant_turns:\n                break\n\n            # reach max user turns\n            if self.max_user_turns and user_turns >= self.max_user_turns:\n                break\n\n            # no tool calls\n            tool_calls = await self.tool_parser.extract_tool_calls(response_ids, prompt_ids)\n            if not tool_calls:\n                break\n\n            # call tools\n            tasks = []\n            for tool_call in tool_calls[: self.max_parallel_calls]:\n                tasks.append(self._call_tool(tool_call))\n            with _timer(\"tool_calls\", metrics):\n                tool_responses = await asyncio.gather(*tasks)\n            if any(isinstance(item, Exception) for item in tool_responses):\n                break\n\n            # append tool_response_ids\n            tool_response_ids = await self.loop.run_in_executor(\n                None,\n                lambda messages=tool_responses: self.tokenizer.apply_chat_template(\n                    messages, add_generation_prompt=True, tokenize=True\n                ),\n            )\n            tool_response_ids = tool_response_ids[len(self.system_prompt) :]\n\n            # NOTE: last turn should not be user turn, or the EOS token reward\n            # can't be propagated to previous token in GAE.\n            if len(response_mask) + len(tool_response_ids) >= self.response_length:\n                break\n\n            prompt_ids += tool_response_ids\n            \n            response_mask += [0] * len(tool_response_ids)\n            user_turns += 1\n\n        response_ids = prompt_ids[-len(response_mask) :]\n        prompt_ids = prompt_ids[: len(prompt_ids) - len(response_mask)]\n\n        output = AgentLoopOutput(\n            prompt_ids=prompt_ids,\n            response_ids=response_ids[: self.response_length],\n            response_mask=response_mask[: self.response_length],\n            num_turns=user_turns + assistant_turns + 1,\n            metrics=metrics,\n        )\n        return output\n\n    async def _call_tool(self, tool_call: FunctionCall) -> Dict[str, str]:\n        \"\"\"Call tool and return tool response.\"\"\"\n        tool, instance_id = None, None\n        try:\n            # TODO: append malformed tool_call to the prompt: invalid function name or arguments\n            tool_name = 
tool_call.name\n            tool_args = json.loads(tool_call.arguments)\n            tool = self.tools[tool_name]\n\n            instance_id = await tool.create()\n            tool_response, tool_reward_score, tool_metrics = await tool.execute(instance_id, tool_args)\n        except Exception as e:\n            logger.exception(f\"Error when executing tool: {e}\")\n            return e\n        finally:\n            if tool and instance_id:\n                await tool.release(instance_id)\n\n        if len(tool_response) > self.max_tool_response_length:\n            if self.tool_response_truncate_side == \"left\":\n                tool_response = tool_response[: self.max_tool_response_length] + \"...(truncated)\"\n            elif self.tool_response_truncate_side == \"right\":\n                tool_response = \"(truncated)...\" + tool_response[-self.max_tool_response_length :]\n            else:\n                length = self.max_tool_response_length // 2\n                tool_response = tool_response[:length] + \"...(truncated)...\" + tool_response[-length:]\n\n        return {\n            \"role\": \"tool\",\n            \"content\": tool_response,\n        }\n\n    @classmethod\n    def get_tool_parser(cls, name: str) -> ToolParser:\n        tool_parsers = {\n            \"hermes\": HermesToolParser,\n        }\n        if name not in tool_parsers:\n            raise ValueError(f\"Unknown tool parser: {name}\")\n        return tool_parsers[name](cls.tokenizer)\n"
  },
  {
    "path": "siirl/execution/rollout_flow/multiturn/interactions/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "siirl/execution/rollout_flow/multiturn/interactions/base.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom typing import Any, Dict, List, Optional, Tuple\nfrom uuid import uuid4\n\n\nclass BaseInteraction:\n    def __init__(self, config: Dict[str, Any]):\n        self.config = config\n        self.name: str = config.get(\"name\", \"interaction_agent\")  # More general agent default role name\n\n    async def start_interaction(self, instance_id: Optional[str] = None, **kwargs) -> str:\n        \"\"\"Create a tool instance.\n\n        Args:\n            instance_id: The instance id of the tool.\n\n        Returns:\n            The instance id of the tool.\n        \"\"\"\n        if instance_id is None:\n            return str(uuid4())\n        else:\n            return instance_id\n\n    async def generate_response(\n        self, instance_id: str, messages: List[Dict[str, Any]], **kwargs\n    ) -> Tuple[bool, str, float, Dict[str, Any]]:  # More clear response generation method\n        \"\"\"\n        Generates a response for the current turn of interaction.\n        Returns a tuple containing:\n        - should_terminate_sequence (bool): True if the interaction sequence should end.\n        - response_content (str): The textual content of the response.\n        - current_turn_score (float): The score for this specific turn/response.\n        - additional_data (dict): Any extra information or metadata.\n        \"\"\"\n        should_terminate_sequence: bool = False  # if True, end rollout\n        response_content: str = \"Your current result seems acceptable.\"\n        current_turn_score: float = 0.8\n        additional_data: Dict[str, Any] = {}\n        return should_terminate_sequence, response_content, current_turn_score, additional_data\n\n    async def calculate_score(self) -> float:  # More clear score calculation method\n        \"\"\"\n        Calculates a score for the interaction,\n        potentially considering aspects like partial exposure & in-context task switching.\n        should be invoke at turn-level\n        \"\"\"\n        # ...implement the logic to calculate turn-level score...\n        score = 0.0\n        return score\n\n    async def finalize_interaction(self) -> None:  # More clear interaction end and resource release method\n        \"\"\"\n        Finalizes the interaction session and releases any associated state or resources.\n        Simulates: release state\n        \"\"\"\n        # ...implement the logic to release state...\n        pass\n"
  },
  {
    "path": "siirl/execution/rollout_flow/multiturn/interactions/gsm8k_interaction.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom loguru import logger\nimport os\nfrom typing import Any, Dict, List, Optional, Tuple\nfrom uuid import uuid4\n\nfrom siirl.utils.reward_score import gsm8k\n\nfrom .base import BaseInteraction\n\n\n\nclass Gsm8kInteraction(BaseInteraction):\n    \"\"\"A demo interaction for calculating the reward of gsm8k.\n\n    - `start_interaction`: start a interaction instance for a trajectory.\n    - `generate_response`: generate the response of the user.\n    - `calculate_score`: calculate the score of the interaction.\n    - `finalize_interaction`: finalize the interaction instance.\n    \"\"\"\n\n    def __init__(self, config: dict):\n        super().__init__(config)\n        self._instance_dict = {}\n\n    async def start_interaction(\n        self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs\n    ) -> str:\n        if instance_id is None:\n            instance_id = str(uuid4())\n        self._instance_dict[instance_id] = {\n            \"response\": \"\",\n            \"ground_truth\": ground_truth,\n            \"reward\": 0.0,\n        }\n        return instance_id\n\n    async def generate_response(\n        self, instance_id: str, messages: List[Dict[str, Any]], **kwargs\n    ) -> Tuple[bool, str, float, dict]:\n        content = \"\"\n        for i in range(len(messages) - 1, -1, -1):\n            item = messages[i]\n            if item.get(\"role\") == \"user\":\n                content = item.get(\"content\")\n                break\n\n        if content and content.startswith(\"#### \"):\n            self._instance_dict[instance_id][\"response\"] = content\n        else:\n            self._instance_dict[instance_id][\"response\"] = \"#### \" + (content or \"\")\n\n        reward = await self.calculate_score(instance_id)\n        if reward == 1.0:\n            response = \"Your response is correct!\"\n            should_terminate_sequence = True\n        else:\n            response = \"Your response is incorrect! You need to reflect on your answer and try again.\"\n            should_terminate_sequence = False\n\n        return should_terminate_sequence, response, reward, {}\n\n    async def calculate_score(self, instance_id: str, **kwargs) -> float:\n        return gsm8k.compute_score(\n            self._instance_dict[instance_id][\"response\"],\n            self._instance_dict[instance_id][\"ground_truth\"],\n            method=\"flexible\",\n            format_score=0.0,\n            score=1.0,\n        )\n\n    async def finalize_interaction(self, instance_id: str, **kwargs) -> None:\n        del self._instance_dict[instance_id]\n"
  },
  {
    "path": "siirl/execution/rollout_flow/multiturn/interactions/utils/__init__.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "siirl/execution/rollout_flow/multiturn/interactions/utils/interaction_registry.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport importlib.util\nimport logging\nimport os\nimport sys\n\nfrom omegaconf import OmegaConf\nfrom loguru import logger\n\n\ndef get_interaction_class(cls_name):\n    \"\"\"Dynamically import and return the interaction class.\"\"\"\n    module_name, class_name = cls_name.rsplit(\".\", 1)\n    if module_name not in sys.modules:\n        spec = importlib.util.find_spec(module_name)\n        module = importlib.util.module_from_spec(spec)\n        sys.modules[module_name] = module\n        spec.loader.exec_module(module)\n    else:\n        module = sys.modules[module_name]\n\n    interaction_cls = getattr(module, class_name)\n    return interaction_cls\n\n\ndef initialize_interactions_from_config(interaction_config_file):\n    \"\"\"Initialize interactions from configuration file.\n\n    Args:\n        interaction_config_file: Path to the interaction configuration file.\n\n    Returns:\n        dict: A dictionary mapping interaction names to BaseInteraction instances.\n    \"\"\"\n    interaction_config = OmegaConf.load(interaction_config_file)\n    interaction_map = {}\n\n    for interaction_item in interaction_config.interaction:\n        cls_name = interaction_item.class_name\n        interaction_cls = get_interaction_class(cls_name)\n\n        # Extract config and name\n        config = OmegaConf.to_container(interaction_item.config, resolve=True)\n\n        # Get the interaction name - either from config or derive from class name\n        name = interaction_item.get(\"name\", None)\n        if name is None:\n            # If no name is specified, use the class name as default\n            class_simple_name = cls_name.split(\".\")[-1]\n            # Remove \"Interaction\" suffix if present, otherwise use full class name\n            if class_simple_name.endswith(\"Interaction\"):\n                name = class_simple_name[:-11].lower()  # Remove \"Interaction\" (11 chars)\n            else:\n                name = class_simple_name.lower()\n\n        # Check for duplicate names\n        if name in interaction_map:\n            raise ValueError(f\"Duplicate interaction name '{name}' found. Each interaction must have a unique name.\")\n\n        # Inject the name into the config\n        config[\"name\"] = name\n\n        # Create the interaction instance\n        interaction = interaction_cls(config=config)\n        interaction_map[name] = interaction\n\n        logger.info(f\"Initialized interaction '{name}' with class '{cls_name}'\")\n\n    return interaction_map\n"
  },
  {
    "path": "siirl/execution/rollout_flow/multiturn/tools/__init__.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "siirl/execution/rollout_flow/multiturn/tools/base_tool.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport json\nfrom typing import Any, Optional, Tuple\nfrom uuid import uuid4\n\nfrom .schemas import OpenAIFunctionToolSchema\n\n\nclass BaseTool:\n    \"\"\"Base class for tools.\n\n    A tool should support the following methods:\n\n    - `to_openai_function_tool_schema`: return the tool schema in OpenAI format.\n    - `create`: create a tool instance for a trajectory.\n    - `execute`: execute the tool.\n    - `calc_reward`: calculate the reward respect to tool state.\n    - `release`: release the tool instance.\n    \"\"\"\n\n    def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):\n        self.config = config\n        self.tool_schema = tool_schema or self.get_openai_tool_schema()\n        assert self.tool_schema is not None, \"Tool schema is not set!\"\n        self.name = self.tool_schema.function.name\n        print(json.dumps(self.tool_schema.model_dump(exclude_unset=True, exclude_none=True), indent=2))\n\n    def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:\n        return self.tool_schema\n\n    async def create(self, instance_id: Optional[str] = None, **kwargs) -> str:\n        \"\"\"Create a tool instance.\n\n        Args:\n            instance_id: The instance id of the tool.\n\n        Returns:\n            The instance id of the tool.\n        \"\"\"\n        if instance_id is None:\n            return str(uuid4())\n        else:\n            return instance_id\n\n    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]:\n        \"\"\"Execute the tool.\n\n        Args:\n            instance_id: The instance id of the tool.\n            parameters: The json string of the parameters of the tool.\n\n        Returns: tool_response, tool_reward_score, tool_metrics\n            tool_response: The response str of the tool.\n            tool_reward_score: The step reward score of the tool.\n            tool_metrics: The metrics of the tool.\n        \"\"\"\n        return \"Updated the tool state.\", 0.0, {}\n\n    async def calc_reward(self, instance_id: str, **kwargs) -> float:\n        \"\"\"Calculate the reward of the tool.\n\n        Args:\n            instance_id: The instance id of the tool.\n\n        Returns:\n            The reward of the tool.\n        \"\"\"\n        return 0.0\n\n    async def release(self, instance_id: str, **kwargs) -> None:\n        \"\"\"Release the tool instance.\n\n        Args:\n            instance_id: The instance id of the tool.\n        \"\"\"\n        pass\n"
  },
  {
    "path": "siirl/execution/rollout_flow/multiturn/tools/geo3k_tool.py",
    "content": "# Copyright 2023-2025 SGLang Team\n# Copyright Amazon.com, Inc. or its affiliates.\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nfrom typing import Any, Optional, Tuple\nfrom uuid import uuid4\n\nfrom siirl.utils.reward_score import geo3k\n\nfrom .base_tool import BaseTool\nfrom .schemas import OpenAIFunctionToolSchema\nfrom loguru import logger\n\nclass Geo3kTool(BaseTool):\n    \"\"\"A demo tool for calculating the reward of geo3k.\n    - `to_openai_function_tool_schema`: return the tool schema in OpenAI format.\n    - `create`: create a tool instance for a trajectory.\n    - `execute`: execute the tool.\n    - `calc_reward`: calculate the reward respect to tool state.\n    - `release`: release the tool instance.\n    \"\"\"\n\n    def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):\n        \"\"\"\n        _tool_schema = OpenAIFunctionToolSchema.model_validate({\n            \"type\": \"function\",\n            \"function\": {\n                \"name\": \"calc_geo3k_reward\",\n                \"description\": \"A tool for calculating the reward of geo3k\",\n                \"parameters\": {\n                    \"type\": \"object\",\n                    \"properties\": {\n                        \"answer\": {\n                            \"type\": \"string\",\n                            \"description\": \"The answer to the question, enclosed in \\\\boxed{}\",\n                        },\n                    },\n                    \"required\": [\"answer\"],\n                },\n            }\n        })\n        \"\"\"\n        super().__init__(config, tool_schema)\n        self._instance_dict = {}\n\n    def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:\n        return self.tool_schema\n\n    async def create(self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs) -> str:\n        if instance_id is None:\n            instance_id = str(uuid4())\n        self._instance_dict[instance_id] = {\n            \"response\": \"\",\n            \"ground_truth\": ground_truth,\n            \"reward\": 0.0,\n        }\n        return instance_id\n\n    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]:\n        answer = parameters.get(\"answer\", \"\")\n        if not isinstance(answer, str):\n            answer = str(answer)\n        print(\"instance_id\", instance_id)\n        print(\"self._instance_dict key\", self._instance_dict.keys())\n        self._instance_dict[instance_id][\"response\"] = answer\n        reward = await self.calc_reward(instance_id)\n        # penalty for non improved answer submission\n        tool_reward = 0.0 if reward > self._instance_dict[instance_id][\"reward\"] else -0.05\n        # update the reward\n        self._instance_dict[instance_id][\"reward\"] = reward\n        return f\"Current parsed {answer=} {reward=}\", tool_reward, {}\n\n    
async def calc_reward(self, instance_id: str, **kwargs) -> float:\n        return geo3k.compute_score(\n            self._instance_dict[instance_id][\"response\"],\n            self._instance_dict[instance_id][\"ground_truth\"],\n            use_boxed=False,\n            format_score=0.0,\n        )\n\n    async def release(self, instance_id: str, **kwargs) -> None:\n        del self._instance_dict[instance_id]\n"
  },
  {
    "path": "siirl/execution/rollout_flow/multiturn/tools/gsm8k_tool.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nfrom typing import Any, Optional, Tuple\nfrom uuid import uuid4\n\nfrom siirl.utils.reward_score import gsm8k\n\nfrom .base_tool import BaseTool\nfrom .schemas import OpenAIFunctionToolSchema\nfrom loguru import logger\n\n\n\nclass Gsm8kTool(BaseTool):\n    \"\"\"A demo tool for calculating the reward of gsm8k.\n\n    - `to_openai_function_tool_schema`: return the tool schema in OpenAI format.\n    - `create`: create a tool instance for a trajectory.\n    - `execute`: execute the tool.\n    - `calc_reward`: calculate the reward respect to tool state.\n    - `release`: release the tool instance.\n    \"\"\"\n\n    def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):\n        \"\"\"\n        _tool_schema = OpenAIFunctionToolSchema.model_validate({\n            \"type\": \"function\",\n            \"function\": {\n                \"name\": \"calc_gsm8k_reward\",\n                \"description\": \"A tool for calculating the reward of gsm8k\",\n                \"parameters\": {\n                    \"type\": \"object\",\n                    \"properties\": {\n                        \"answer\": {\n                            \"type\": \"string\",\n                            \"description\": \"The answer to the question\",\n                        },\n                    },\n                    \"required\": [\"answer\"],\n                },\n            }\n        })\n        \"\"\"\n        super().__init__(config, tool_schema)\n        self._instance_dict = {}\n\n    def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:\n        return self.tool_schema\n\n    async def create(self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs) -> str:\n        if instance_id is None:\n            instance_id = str(uuid4())\n        self._instance_dict[instance_id] = {\n            \"response\": \"\",\n            \"ground_truth\": ground_truth,\n            \"reward\": 0.0,\n        }\n        return instance_id\n\n    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]:\n        answer = parameters.get(\"answer\", \"\")\n        if not isinstance(answer, str):\n            answer = str(answer)\n\n        if answer.startswith(\"#### \"):\n            self._instance_dict[instance_id][\"response\"] = answer\n        else:\n            self._instance_dict[instance_id][\"response\"] = \"#### \" + answer\n\n        reward = await self.calc_reward(instance_id)\n        # penalty for non improved answer submission\n        tool_reward = 0.0 if reward > self._instance_dict[instance_id][\"reward\"] else -0.05\n        # update the reward\n        self._instance_dict[instance_id][\"reward\"] = reward\n\n        return f\"Current parsed {answer=} {reward=}\", tool_reward, {}\n\n    async def calc_reward(self, 
instance_id: str, **kwargs) -> float:\n        return gsm8k.compute_score(\n            self._instance_dict[instance_id][\"response\"],\n            self._instance_dict[instance_id][\"ground_truth\"],\n            method=\"flexible\",\n            format_score=0.0,\n            score=1.0,\n        )\n\n    async def release(self, instance_id: str, **kwargs) -> None:\n        del self._instance_dict[instance_id]\n"
  },
  {
    "path": "siirl/execution/rollout_flow/multiturn/tools/mcp_base_tool.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport logging\nimport os\nfrom typing import Any, Optional, Tuple\nfrom uuid import uuid4\n\nfrom fastmcp.exceptions import ClientError\n\nfrom siirl.execution.rollout_flow.multiturn.tools.utils.mcp_clients.McpClientManager import ClientManager\n\nfrom .base_tool import BaseTool\nfrom .schemas import OpenAIFunctionToolSchema\nfrom loguru import logger\n\nclass MCPBaseTool(BaseTool):\n    def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):\n        super().__init__(config, tool_schema)\n        self._instance_dict = {}\n        self.timeout = config.get(\"timeout\", 30)\n\n        # TODO(hechanghao): create a global client manager to manage the rate limit, client and pool\n        logger.info(f\"Initialized MCPBaseTool with config: {config}\")\n\n    def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:\n        \"\"\"Return the OpenAI tool schema.\"\"\"\n        return self.tool_schema\n\n    async def create(self, instance_id: Optional[str] = None, **kwargs) -> str:\n        \"\"\"Create a tool instance.\n\n        Args:\n            instance_id: The instance id of the tool.\n\n        Returns:\n            The instance id of the tool.\n        \"\"\"\n        if instance_id is None:\n            instance_id = str(uuid4())\n        self._instance_dict[instance_id] = {\n            \"response\": \"\",\n            \"reward\": [],\n        }\n        return instance_id\n\n    async def _call_tool(self, instance_id, parameters) -> Tuple[str, dict]:\n        err_msg = \"\"\n        try:\n            call_tool_result = await ClientManager.call_tool(self.name, parameters, self.timeout)\n        except ClientError as e:\n            err_msg = f\"\\n Tool call failed: {e}\"\n        except ConnectionError as e:\n            err_msg = f\"\\n Connection failed: {e}\"\n        except Exception as e:\n            err_msg = f\"\\n An unexpected error occurred: {e}\"\n\n        logger.debug(f\"Tool result for instance {instance_id} with tool {self.name}: {call_tool_result.content}\")\n        result, metadata = self._parse_tool_result(call_tool_result.content)\n        metadata[\"api_request_error\"] += err_msg\n        return result, metadata\n\n    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]:\n        if self.name == \"\" or self.name is None or parameters is None:\n            error_msg = \"Error: 'parameters' is missing or empty.\"\n            logger.error(f\"[MCPTool] {error_msg} Received tool name: {self.name}, parameters: {parameters}\")\n            return json.dumps({\"result\": error_msg}), 0.0, {}\n\n        try:\n            result_text, metadata = await self._call_tool(instance_id, parameters)\n\n            # Store results in instance dictionary\n            self._instance_dict[instance_id][\"reward\"].append(result_text.strip())\n\n            # Convert 
metadata to metrics\n            metrics = {\n                \"query_count\": metadata.get(\"query_count\", 0),\n                \"status\": metadata.get(\"status\", \"unknown\"),\n                \"total_results\": metadata.get(\"total_results\", 0),\n                \"api_request_error\": metadata.get(\"api_request_error\"),\n            }\n\n            return result_text, 0.0, metrics\n\n        except Exception as e:\n            error_result = json.dumps({\"result\": f\"Tool execution failed: {e}\"})\n            logger.error(f\"[MCPBaseTool] Execution failed: {e}\")\n            return error_result, 0.0, {\"error\": str(e)}\n\n    async def calc_reward(self, instance_id: str, **kwargs) -> str:\n        return self._instance_dict[instance_id][\"reward\"]\n\n    async def release(self, instance_id: str, **kwargs) -> None:\n        if instance_id in self._instance_dict:\n            del self._instance_dict[instance_id]\n\n    def _parse_tool_result(self, content: list) -> Tuple[str, dict]:\n        tools_content = [part.text for part in filter(lambda x: x.type == \"text\", content)]\n        return \" \".join(tools_content), {}\n"
  },
  {
    "path": "siirl/execution/rollout_flow/multiturn/tools/mcp_search_tool.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport logging\nimport os\nimport re\nfrom typing import Tuple\n\nfrom siirl.execution.rollout_flow.multiturn.tools.mcp_base_tool import MCPBaseTool\n\nfrom loguru import logger\n\n\nclass MCPSearchTool(MCPBaseTool):\n    def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):\n        super().__init__(config, tool_schema)\n\n    def _parse_tool_result(self, content: list) -> Tuple[str, dict]:\n        res = \"\"\n        res_cnt = 0\n        query_list = []\n        metadata = {\n            \"api_request_error\": \"\",\n            \"status\": \"unknown\",\n            \"total_results\": 0,\n        }\n        try:\n            for part in content:\n                if part.type != \"text\":\n                    continue\n                text = part.text.replace(\"'\", '\"')\n                query_match = re.search(r'query\"\\s*:\\s*\"([^\"]+)\"', text)\n                query = query_match.group(1) if query_match else \"\"\n                query_list.append(query)\n\n                title_matches = re.findall(r'\"title\"\\s*:', text)\n                title_count = len(title_matches)\n\n                results_match = re.search(r'\"results\"\\s*:\\s*(\\[.*?\\])', text, re.DOTALL)\n                results_content = results_match.group(1) if results_match else \"\"\n\n                res += results_content\n                res_cnt += title_count\n        except json.JSONDecodeError:\n            err_msg = \"json parse error.\"\n            logger.error(err_msg)\n            metadata[\"api_request_error\"] = err_msg\n            metadata[\"status\"] = \"error\"\n\n        # update metadata\n        metadata[\"status\"] = \"success\"\n        metadata[\"queries\"] = query_list\n        metadata[\"query_count\"] = len(query_list)\n        metadata[\"total_results\"] = res_cnt\n        return res, metadata\n"
  },
  {
    "path": "siirl/execution/rollout_flow/multiturn/tools/sandbox_fusion_tools.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nimport threading\nfrom contextlib import ExitStack\nfrom enum import Enum\nfrom typing import Any, Callable, Optional, Tuple, TypeVar\nfrom uuid import uuid4\n\nimport ray\nimport ray.actor\nimport ray.util.multiprocessing\n\nfrom siirl.execution.rollout_flow.multiturn.tools.base_tool import BaseTool\nfrom siirl.utils.reward_score.sandbox_fusion.utils import _process_single_case\n\nfrom .schemas import OpenAIFunctionToolSchema\nfrom loguru import logger\n\n\nT = TypeVar(\"T\")\n\n\nclass PoolMode(Enum):\n    ThreadMode = 1\n    ProcessMode = 2\n\n\n@ray.remote(concurrency_groups={\"acquire\": 1, \"release\": 10})\nclass TokenBucketWorker:\n    def __init__(self, rate_limit: int):\n        self.rate_limit = rate_limit\n        # this only used for observalability\n        self.current_count = 0\n        self._semaphore = threading.Semaphore(rate_limit)\n\n    @ray.method(concurrency_group=\"acquire\")\n    def acquire(self):\n        self._semaphore.acquire()\n        self.current_count += 1\n\n    @ray.method(concurrency_group=\"release\")\n    def release(self):\n        self._semaphore.release()\n        self.current_count -= 1\n\n    def get_current_count(self):\n        return self.current_count\n\n\nclass ExecutionWorker:\n    def __init__(self, enable_global_rate_limit=True, rate_limit=10):\n        self.rate_limit_worker = self._init_rate_limit(rate_limit) if enable_global_rate_limit else None\n\n    def _init_rate_limit(self, rate_limit):\n        # TODO validation for rate_limit\n        # A Singleton Rate Limitor\n        return TokenBucketWorker.options(name=\"rate-limiter\", get_if_exists=True).remote(rate_limit)\n\n    def ping(self):\n        return True\n\n    def execute(self, fn: Callable[..., T], *fn_args, **fn_kwargs) -> T:\n        with ExitStack() as stack:\n            stack.callback(self.rate_limit_worker.release.remote)\n            ray.get(self.rate_limit_worker.acquire.remote())\n            try:\n                return fn(*fn_args, **fn_kwargs)\n            except Exception as e:\n                # TODO we should make this available to the tool caller\n                logger.warning(f\"Error when executing code: {e}\")\n\n\ndef init_execution_pool(\n    num_workers: int, enable_global_rate_limit=True, rate_limit=10, mode: PoolMode = PoolMode.ThreadMode\n):\n    if mode == PoolMode.ThreadMode:\n        return (\n            ray.remote(ExecutionWorker)\n            .options(max_concurrency=num_workers)\n            .remote(enable_global_rate_limit=enable_global_rate_limit, rate_limit=rate_limit)\n        )\n    else:\n        raise NotImplementedError(\"Process mode is not implemented yet\")\n        # return ray.util.multiprocessing.Pool(processes=num_workers)\n\n\nclass SandboxFusionTool(BaseTool):\n    \"\"\"A tool for executing the code using sanbox fusion image.\n\n    - `to_openai_function_tool_schema`: return 
the tool schema in OpenAI format.\n    - `create`: create a tool instance for a trajectory.\n    - `execute`: execute the tool.\n    - `calc_reward`: calculate the reward respect to tool state.\n    - `release`: release the tool instance.\n    \"\"\"\n\n    def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):\n        \"\"\"\n        _tool_schema = OpenAIFunctionToolSchema.model_validate({\n            \"type\": \"function\",\n            \"function\": {\n                \"name\": \"code_interpreter\",\n                \"description\": \"A tool for execute code\",\n                \"parameters\": {\n                    \"type\": \"object\",\n                    \"properties\": {\n                        \"code\": {\n                            \"type\": \"string\",\n                            \"description\": \"code needs to be execute and grad\",\n                        },\n                    },\n                    \"required\": [\"code\"],\n                },\n            }\n        })\n        \"\"\"\n        super().__init__(config, tool_schema)\n        self._instance_dict = {}\n        # TODO: better documentation for the config\n        self.num_workers = config.get(\"num_workers\", 10)\n        self.rate_limit = config.get(\"rate_limit\", 10)\n        self.default_timeout = config.get(\"default_timeout\", 30)\n        self.default_language = config.get(\"default_language\", \"python\")\n        self.enable_global_rate_limit = config.get(\"enable_global_rate_limit\", True)\n        self.execution_pool = init_execution_pool(\n            num_workers=self.num_workers,\n            enable_global_rate_limit=self.enable_global_rate_limit,\n            rate_limit=self.rate_limit,\n            mode=PoolMode.ThreadMode,\n        )\n        self.sandbox_fusion_url = config.get(\"sandbox_fusion_url\", \"\")\n        self.memory_limit_mb = config.get(\"memory_limit_mb\", 1024)\n        if self.sandbox_fusion_url == \"\":\n            raise ValueError(\"sandbox_fusion_url is not set\")\n        log_msg = f\"Init SandboxFusionTool with config: {config}\"\n        logger.info(log_msg)\n\n    def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:\n        return self.tool_schema\n\n    async def create(self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs) -> str:\n        if instance_id is None:\n            instance_id = str(uuid4())\n        self._instance_dict[instance_id] = {\n            \"response\": \"\",\n            \"ground_truth\": ground_truth,\n            \"reward\": [],\n        }\n        return instance_id\n\n    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]:\n        code = parameters.get(\"code\", \"\")\n        timeout = parameters.get(\"timeout\", self.default_timeout)\n        language = parameters.get(\"language\", self.default_language)\n        if not isinstance(code, str):\n            code = str(code)\n\n        result = await self.execution_pool.execute.remote(self.execute_code, instance_id, code, timeout, language)\n        # sandbox has no score or metrics, use Nones\n        return result, None, None\n\n    def execute_code(self, instance_id, code, timeout=30, language=\"python\"):\n        result_status, metadata = _process_single_case(\n            0, None, None, self.sandbox_fusion_url, code, timeout, self.memory_limit_mb, language\n        )\n        # we should always expect this since we don't have correct answer\n        if 
metadata[\"run_status\"] == \"Finished\":\n            actual_output = metadata[\"stdout\"] + metadata[\"stderr\"]\n            logger.debug(f\"actual_output from sandbox fusion: {actual_output},{instance_id}\")\n            return actual_output\n        else:\n            return \"no stdout here\"\n\n    async def calc_reward(self, instance_id: str, **kwargs) -> str:\n        return self._instance_dict[instance_id][\"reward\"]\n\n    async def release(self, instance_id: str, **kwargs) -> None:\n        del self._instance_dict[instance_id]\n"
  },
  {
    "path": "siirl/execution/rollout_flow/multiturn/tools/schemas.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport json\nfrom typing import Any, Literal\n\nfrom pydantic import BaseModel\n\n\nclass OpenAIFunctionPropertySchema(BaseModel):\n    \"\"\"The schema of a parameter in OpenAI format.\"\"\"\n\n    type: str\n    description: str | None = None\n    enum: list[str] | None = None\n\n\nclass OpenAIFunctionParametersSchema(BaseModel):\n    \"\"\"The schema of parameters in OpenAI format.\"\"\"\n\n    type: str\n    properties: dict[str, OpenAIFunctionPropertySchema]\n    required: list[str]\n\n\nclass OpenAIFunctionSchema(BaseModel):\n    \"\"\"The schema of a function in OpenAI format.\"\"\"\n\n    name: str\n    description: str\n    parameters: OpenAIFunctionParametersSchema\n    strict: bool = False\n\n\nclass OpenAIFunctionToolSchema(BaseModel):\n    \"\"\"The schema of a tool in OpenAI format.\"\"\"\n\n    type: str\n    function: OpenAIFunctionSchema\n\n\nclass OpenAIFunctionParsedSchema(BaseModel):\n    \"\"\"The parsed schema of a tool in OpenAI format.\"\"\"\n\n    name: str\n    arguments: str  # JSON string\n\n\nclass OpenAIFunctionCallSchema(BaseModel):\n    \"\"\"The parsed schema of a tool in OpenAI format.\"\"\"\n\n    name: str\n    arguments: dict[str, Any]\n\n    @staticmethod\n    def from_openai_function_parsed_schema(\n        parsed_schema: OpenAIFunctionParsedSchema,\n    ) -> tuple[\"OpenAIFunctionCallSchema\", bool]:\n        has_decode_error = False\n        try:\n            arguments = json.loads(parsed_schema.arguments)\n        except json.JSONDecodeError:\n            arguments = {}\n            has_decode_error = True\n        # If the arguments is not a dict, it means the arguments is not a valid JSON string\n        if not isinstance(arguments, dict):\n            arguments = {}\n            has_decode_error = True\n\n        return OpenAIFunctionCallSchema(name=parsed_schema.name, arguments=arguments), has_decode_error\n\n\nclass OpenAIFunctionToolCall(BaseModel):\n    \"\"\"The tool call in OpenAI format.\"\"\"\n\n    id: str\n    type: Literal[\"function\"] = \"function\"\n    function: OpenAIFunctionCallSchema\n"
  },
  {
    "path": "siirl/execution/rollout_flow/multiturn/tools/search_tool.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\r\n# Copyright 2023-2024 SGLang Team\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#     http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n\r\nimport json\r\nimport logging\r\nimport os\r\nimport threading\r\nfrom contextlib import ExitStack\r\nfrom enum import Enum\r\nfrom typing import Any, Callable, Optional, Tuple, TypeVar\r\nfrom uuid import uuid4\r\n\r\nimport ray\r\nimport ray.actor\r\n\r\nfrom siirl.execution.rollout_flow.multiturn.tools.utils.search_r1_like_utils import perform_single_search_batch\r\n\r\nfrom .base_tool import BaseTool\r\nfrom .schemas import OpenAIFunctionToolSchema\r\n\r\nfrom loguru import logger\r\n\r\nT = TypeVar(\"T\")\r\n\r\n\r\n# Adapted from siirl/tools/sandbox_fusion_tools.py\r\nclass PoolMode(Enum):\r\n    \"\"\"Execution pool mode enumeration.\"\"\"\r\n\r\n    ThreadMode = 1\r\n    ProcessMode = 2\r\n\r\n\r\n@ray.remote(concurrency_groups={\"acquire\": 1, \"release\": 10})\r\nclass TokenBucketWorker:\r\n    \"\"\"Ray actor for rate limiting using token bucket algorithm.\"\"\"\r\n\r\n    def __init__(self, rate_limit: int):\r\n        self.rate_limit = rate_limit\r\n        self.current_count = 0  # For observability\r\n        self._semaphore = threading.Semaphore(rate_limit)\r\n\r\n    @ray.method(concurrency_group=\"acquire\")\r\n    def acquire(self):\r\n        \"\"\"Acquire a token from the bucket.\"\"\"\r\n        self._semaphore.acquire()\r\n        self.current_count += 1\r\n\r\n    @ray.method(concurrency_group=\"release\")\r\n    def release(self):\r\n        \"\"\"Release a token back to the bucket.\"\"\"\r\n        self._semaphore.release()\r\n        self.current_count -= 1\r\n\r\n    def get_current_count(self):\r\n        \"\"\"Get current number of acquired tokens.\"\"\"\r\n        return self.current_count\r\n\r\n\r\nclass SearchExecutionWorker:\r\n    \"\"\"Worker for executing search operations with optional rate limiting.\"\"\"\r\n\r\n    def __init__(self, enable_global_rate_limit=True, rate_limit=10):\r\n        self.rate_limit_worker = self._init_rate_limit(rate_limit) if enable_global_rate_limit else None\r\n\r\n    def _init_rate_limit(self, rate_limit):\r\n        \"\"\"Initialize singleton rate limiter.\"\"\"\r\n        return TokenBucketWorker.options(name=\"rate-limiter\", get_if_exists=True).remote(rate_limit)\r\n\r\n    def ping(self):\r\n        \"\"\"Health check method.\"\"\"\r\n        return True\r\n\r\n    def execute(self, fn: Callable[..., T], *fn_args, **fn_kwargs) -> T:\r\n        \"\"\"Execute function with optional rate limiting.\"\"\"\r\n        if self.rate_limit_worker:\r\n            with ExitStack() as stack:\r\n                stack.callback(self.rate_limit_worker.release.remote)\r\n                ray.get(self.rate_limit_worker.acquire.remote())\r\n                try:\r\n                    return fn(*fn_args, **fn_kwargs)\r\n                except Exception as e:\r\n                    # TODO we should make this available to the tool caller\r\n            
        logger.warning(f\"Error when executing search: {e}\")\r\n        else:\r\n            return fn(*fn_args, **fn_kwargs)\r\n\r\n\r\ndef init_search_execution_pool(\r\n    num_workers: int, enable_global_rate_limit=True, rate_limit=10, mode: PoolMode = PoolMode.ThreadMode\r\n):\r\n    \"\"\"Initialize search execution pool.\"\"\"\r\n    if mode == PoolMode.ThreadMode:\r\n        return (\r\n            ray.remote(SearchExecutionWorker)\r\n            .options(max_concurrency=num_workers)\r\n            .remote(enable_global_rate_limit=enable_global_rate_limit, rate_limit=rate_limit)\r\n        )\r\n    else:\r\n        raise NotImplementedError(\"Process mode is not implemented yet\")\r\n\r\n\r\nclass SearchTool(BaseTool):\r\n    \"\"\"Search tool for retrieving information using external retrieval services.\r\n\r\n    This tool provides search functionality with rate limiting and concurrent execution\r\n    support through Ray. It integrates with external retrieval services to perform\r\n    semantic search operations.\r\n\r\n    Methods:\r\n        get_openai_tool_schema: Return the tool schema in OpenAI format\r\n        create: Create a tool instance for a trajectory\r\n        execute: Execute the search tool\r\n        calc_reward: Calculate the reward with respect to tool state\r\n        release: Release the tool instance\r\n    \"\"\"\r\n\r\n    def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):\r\n        \"\"\"Initialize SearchTool with configuration and schema.\r\n\r\n        Args:\r\n            config: Configuration dictionary containing tool settings\r\n            tool_schema: OpenAI function tool schema definition\r\n\r\n        Example tool_schema:\r\n            {\r\n                \"type\": \"function\",\r\n                \"function\": {\r\n                    \"name\": \"search\",\r\n                    \"description\": \"Searches for relevant information based on queries.\",\r\n                    \"parameters\": {\r\n                        \"type\": \"object\",\r\n                        \"properties\": {\r\n                            \"query_list\": {\r\n                                \"type\": \"array\",\r\n                                \"items\": {\"type\": \"string\"},\r\n                                \"description\": \"List of search queries\"\r\n                            }\r\n                        },\r\n                        \"required\": [\"query_list\"]\r\n                    }\r\n                }\r\n            }\r\n        \"\"\"\r\n        super().__init__(config, tool_schema)\r\n        self._instance_dict = {}\r\n\r\n        # Worker and rate limiting configuration\r\n        self.num_workers = config.get(\"num_workers\", 120)\r\n        self.rate_limit = config.get(\"rate_limit\", 120)\r\n        self.timeout = config.get(\"timeout\", 30)\r\n\r\n        self.enable_global_rate_limit = config.get(\"enable_global_rate_limit\", True)\r\n        self.execution_pool = init_search_execution_pool(\r\n            num_workers=self.num_workers,\r\n            enable_global_rate_limit=self.enable_global_rate_limit,\r\n            rate_limit=self.rate_limit,\r\n            mode=PoolMode.ThreadMode,\r\n        )\r\n\r\n        # Retrieval service configuration\r\n        self.retrieval_service_url = config.get(\"retrieval_service_url\")\r\n        assert self.retrieval_service_url, \"Configuration must include 'retrieval_service_url'\"\r\n        self.topk = config.get(\"topk\", 3)\r\n        if self.retrieval_service_url 
== \"\":\r\n            raise ValueError(\"retrieval_service_url is not set\")\r\n\r\n        logger.info(f\"Initialized SearchTool with config: {config}\")\r\n\r\n    def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:\r\n        \"\"\"Return the OpenAI tool schema.\"\"\"\r\n        return self.tool_schema\r\n\r\n    async def create(self, instance_id: Optional[str] = None, **kwargs) -> str:\r\n        \"\"\"Create a tool instance.\r\n\r\n        Args:\r\n            instance_id: The instance id of the tool.\r\n\r\n        Returns:\r\n            The instance id of the tool.\r\n        \"\"\"\r\n        if instance_id is None:\r\n            instance_id = str(uuid4())\r\n        self._instance_dict[instance_id] = {\r\n            \"response\": \"\",\r\n            \"reward\": [],\r\n        }\r\n        return instance_id\r\n\r\n    def execute_search(self, instance_id: str, query_list: list, retrieval_service_url: str, topk: int, timeout: int):\r\n        \"\"\"Execute search operation using retrieval service.\r\n\r\n        Args:\r\n            instance_id: Tool instance ID\r\n            query_list: List of search queries\r\n            retrieval_service_url: URL of the retrieval service\r\n            topk: Number of top results to return\r\n            timeout: Request timeout in seconds\r\n\r\n        Returns:\r\n            Tuple of (result_text, metadata)\r\n        \"\"\"\r\n        result_text, metadata = perform_single_search_batch(\r\n            retrieval_service_url=retrieval_service_url,\r\n            query_list=query_list,\r\n            topk=topk,\r\n            concurrent_semaphore=None,  # Ray handles concurrency control\r\n            timeout=timeout,\r\n        )\r\n        logger.debug(f\"Search result for instance {instance_id}: {result_text}\")\r\n        return result_text, metadata\r\n\r\n    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]:\r\n        \"\"\"Execute the search tool.\r\n\r\n        Args:\r\n            instance_id: The instance ID of the tool\r\n            parameters: Tool parameters containing query_list and optional timeout\r\n\r\n        Returns: tool_response, tool_reward_score, tool_metrics\r\n            tool_response: The response str of the tool.\r\n            tool_reward_score: The step reward score of the tool.\r\n            tool_metrics: The metrics of the tool.\r\n        \"\"\"\r\n        timeout = self.timeout\r\n        query_list_from_params = parameters.get(\"query_list\")\r\n\r\n        if not query_list_from_params or not isinstance(query_list_from_params, list):\r\n            error_msg = \"Error: 'query_list' is missing, empty, or not a list in parameters.\"\r\n            logger.error(f\"[SearchTool] {error_msg} Received parameters: {parameters}\")\r\n            return json.dumps({\"result\": error_msg}), 0.0, {}\r\n\r\n        # Execute search using Ray execution pool\r\n        try:\r\n            result_text, metadata = await self.execution_pool.execute.remote(\r\n                self.execute_search, instance_id, query_list_from_params, self.retrieval_service_url, self.topk, timeout\r\n            )\r\n\r\n            # Store results in instance dictionary\r\n            self._instance_dict[instance_id][\"reward\"].append(result_text.strip())\r\n\r\n            # Convert metadata to metrics\r\n            metrics = {\r\n                \"query_count\": metadata.get(\"query_count\", 0),\r\n                \"status\": metadata.get(\"status\", 
\"unknown\"),\r\n                \"total_results\": metadata.get(\"total_results\", 0),\r\n                \"api_request_error\": metadata.get(\"api_request_error\"),\r\n            }\r\n\r\n            return result_text, 0.0, metrics\r\n\r\n        except Exception as e:\r\n            error_result = json.dumps({\"result\": f\"Search execution failed: {e}\"})\r\n            logger.error(f\"[SearchTool] Execution failed: {e}\")\r\n            return error_result, 0.0, {\"error\": str(e)}\r\n\r\n    async def calc_reward(self, instance_id: str, **kwargs) -> str:\r\n        return self._instance_dict[instance_id][\"reward\"]\r\n\r\n    async def release(self, instance_id: str, **kwargs) -> None:\r\n        if instance_id in self._instance_dict:\r\n            del self._instance_dict[instance_id]\r\n"
  },
  {
    "path": "siirl/execution/rollout_flow/multiturn/tools/utils/__init__.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "siirl/execution/rollout_flow/multiturn/tools/utils/mcp_clients/McpClientManager.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#     http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n\r\nimport asyncio\r\nimport json\r\nimport logging\r\nfrom typing import Any\r\n\r\nfrom fastmcp import Client\r\nfrom fastmcp.client.transports import SSETransport\r\n\r\nfrom siirl.execution.rollout_flow.multiturn.tools.utils.mcp_clients.utils import TokenBucket, mcp2openai\r\n\r\nlogger = logging.getLogger(__name__)\r\n\r\n\r\nclass MCPClientManager:\r\n    rootServerName = \"mcpServers\"\r\n    initialized = False\r\n    clients = []\r\n    tool_client_mapping = {}\r\n    rate_limiter = None\r\n\r\n    async def initialize(self, config_path, rate_limit: float = 10.0):\r\n        if self.initialized:\r\n            return\r\n        \"\"\"Initialize the MCP Client Manager and start all clients\"\"\"\r\n        result = self._load_config(config_path)\r\n        servers = result[self.rootServerName]\r\n        exclude_sse_servers = {self.rootServerName: {}}\r\n        for server_name in servers.keys():\r\n            server = servers[server_name]\r\n            if \"auth_token\" in server:\r\n                transport = SSETransport(url=server[\"url\"], headers={\"Authorization\": f\"Bearer {server['auth_token']}\"})\r\n                client = Client(transport)\r\n                self.clients.append(client)\r\n            else:\r\n                exclude_sse_servers[self.rootServerName][server_name] = server\r\n\r\n        if exclude_sse_servers[self.rootServerName]:\r\n            self.clients.append(Client(exclude_sse_servers))\r\n\r\n        # Initialize rate limiter\r\n        self.rate_limiter = TokenBucket(rate_limit)\r\n        self.initialized = True\r\n\r\n    async def call_tool(self, tool_name, parameters, timeout):\r\n        # Apply rate limiting\r\n        while not self.rate_limiter.acquire():\r\n            await asyncio.sleep(0.1)\r\n\r\n        client = self.get_client_with_tool_name(tool_name)\r\n        async with client:\r\n            return await client.call_tool_mcp(tool_name, parameters)\r\n\r\n    async def fetch_tool_schemas(self, tool_selected_list: list[str]) -> list[dict]:\r\n        tool_schemas = []\r\n        for client in self.clients:\r\n            async with client:\r\n                tools = await client.list_tools_mcp()\r\n                for tool in tools.tools:\r\n                    if not tool_selected_list:\r\n                        self.tool_client_mapping[tool.name] = client\r\n                        tool_schemas.append(mcp2openai(tool))\r\n                    elif tool.name in tool_selected_list:\r\n                        self.tool_client_mapping[tool.name] = client\r\n                        tool_schemas.append(mcp2openai(tool))\r\n\r\n        return tool_schemas\r\n\r\n    def get_client_with_tool_name(self, tool_name: str):\r\n        return self.tool_client_mapping[tool_name]\r\n\r\n    def _load_config(self, file: str) -> dict[str, Any]:\r\n        try:\r\n            with open(file) as 
f:\r\n                return json.load(f)\r\n        except FileNotFoundError:\r\n            logger.warning(f'the \"{file}\" file was not found')\r\n        except Exception:\r\n            logger.error(f'there was an error reading the \"{file}\" file')\r\n\r\n        return {}\r\n\r\n\r\nClientManager = MCPClientManager()\r\n"
  },
  {
    "path": "siirl/execution/rollout_flow/multiturn/tools/utils/mcp_clients/__init__.py",
    "content": ""
  },
  {
    "path": "siirl/execution/rollout_flow/multiturn/tools/utils/mcp_clients/utils.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport threading\nimport time\n\nfrom mcp import Tool\n\nlogger = logging.getLogger(__file__)\n\n\nclass TokenBucket:\n    def __init__(self, rate_limit: float):\n        self.rate_limit = rate_limit  # tokens per second\n        self.tokens = rate_limit\n        self.last_update = time.time()\n        self.lock = threading.Lock()\n\n    def acquire(self) -> bool:\n        with self.lock:\n            now = time.time()\n            # Add new tokens based on time elapsed\n            new_tokens = (now - self.last_update) * self.rate_limit\n            self.tokens = min(self.rate_limit, self.tokens + new_tokens)\n            self.last_update = now\n\n            if self.tokens >= 1:\n                self.tokens -= 1\n                return True\n            return False\n\n\ndef mcp2openai(mcp_tool: Tool) -> dict:\n    \"\"\"Convert a MCP Tool to an OpenAI ChatCompletionTool.\"\"\"\n    openai_format = {\n        \"type\": \"function\",\n        \"function\": {\n            \"name\": mcp_tool.name,\n            \"description\": mcp_tool.description,\n            \"parameters\": mcp_tool.inputSchema,\n            \"strict\": False,\n        },\n    }\n    if not openai_format[\"function\"][\"parameters\"].get(\"required\", None):\n        openai_format[\"function\"][\"parameters\"][\"required\"] = []\n    return openai_format\n"
  },
  {
    "path": "siirl/execution/rollout_flow/multiturn/tools/utils/search_r1_like_utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\r\n# Copyright 2023-2024 SGLang Team\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#     http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n\r\nimport json\r\nimport logging\r\nimport threading\r\nimport time\r\nimport traceback\r\nimport uuid\r\nfrom typing import Any, Dict, List, Optional, Tuple\r\n\r\nimport requests\r\n\r\nDEFAULT_TIMEOUT = 30  # Default search request timeout\r\nMAX_RETRIES = 10\r\nINITIAL_RETRY_DELAY = 1\r\nAPI_TIMEOUT = 10\r\n\r\nlogger = logging.getLogger(__name__)\r\n\r\n\r\ndef call_search_api(\r\n    retrieval_service_url: str,\r\n    query_list: List[str],\r\n    topk: int = 3,\r\n    return_scores: bool = True,\r\n    timeout: int = DEFAULT_TIMEOUT,\r\n) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:\r\n    \"\"\"\r\n    Calls the remote search API to perform retrieval with retry logic for various errors,\r\n    using increasing delay between retries. Logs internal calls with a unique ID.\r\n\r\n    Args:\r\n        retrieval_service_url: The URL of the retrieval service API.\r\n        query_list: List of search queries.\r\n        topk: Number of top results to return.\r\n        return_scores: Whether to return scores.\r\n        timeout: Request timeout in seconds.\r\n\r\n    Returns:\r\n        A tuple (response_json, error_message).\r\n        If successful, response_json is the API's returned JSON object, error_message is None.\r\n        If failed after retries, response_json is None, error_message contains the error information.\r\n    \"\"\"\r\n    request_id = str(uuid.uuid4())\r\n    log_prefix = f\"[Search Request ID: {request_id}] \"\r\n\r\n    payload = {\"queries\": query_list, \"topk\": topk, \"return_scores\": return_scores}\r\n\r\n    headers = {\"Content-Type\": \"application/json\", \"Accept\": \"application/json\"}\r\n\r\n    last_error = None\r\n\r\n    for attempt in range(MAX_RETRIES):\r\n        try:\r\n            logger.info(\r\n                f\"{log_prefix}Attempt {attempt + 1}/{MAX_RETRIES}: Calling search API at {retrieval_service_url}\"\r\n            )\r\n            response = requests.post(\r\n                retrieval_service_url,\r\n                headers=headers,\r\n                json=payload,\r\n                timeout=timeout,\r\n            )\r\n\r\n            # Check for Gateway Timeout (504) and other server errors for retrying\r\n            if response.status_code in [500, 502, 503, 504]:\r\n                last_error = (\r\n                    f\"{log_prefix}API Request Error: Server Error ({response.status_code}) on attempt \"\r\n                    f\"{attempt + 1}/{MAX_RETRIES}\"\r\n                )\r\n                logger.warning(last_error)\r\n                if attempt < MAX_RETRIES - 1:\r\n                    delay = INITIAL_RETRY_DELAY * (attempt + 1)\r\n                    logger.info(f\"{log_prefix}Retrying after {delay} seconds...\")\r\n                    time.sleep(delay)\r\n                continue\r\n\r\n            # Check for other HTTP 
errors (e.g., 4xx)\r\n            response.raise_for_status()\r\n\r\n            # If successful (status code 2xx)\r\n            logger.info(f\"{log_prefix}Search API call successful on attempt {attempt + 1}\")\r\n            return response.json(), None\r\n\r\n        except requests.exceptions.ConnectionError as e:\r\n            last_error = f\"{log_prefix}Connection Error: {e}\"\r\n            logger.warning(last_error)\r\n            if attempt < MAX_RETRIES - 1:\r\n                delay = INITIAL_RETRY_DELAY * (attempt + 1)\r\n                logger.info(f\"{log_prefix}Retrying after {delay} seconds...\")\r\n                time.sleep(delay)\r\n            continue\r\n        except requests.exceptions.Timeout as e:\r\n            last_error = f\"{log_prefix}Timeout Error: {e}\"\r\n            logger.warning(last_error)\r\n            if attempt < MAX_RETRIES - 1:\r\n                delay = INITIAL_RETRY_DELAY * (attempt + 1)\r\n                logger.info(f\"{log_prefix}Retrying after {delay} seconds...\")\r\n                time.sleep(delay)\r\n            continue\r\n        except requests.exceptions.RequestException as e:\r\n            last_error = f\"{log_prefix}API Request Error: {e}\"\r\n            break  # Exit retry loop on other request errors\r\n        except json.JSONDecodeError as e:\r\n            raw_response_text = response.text if \"response\" in locals() else \"N/A\"\r\n            last_error = f\"{log_prefix}API Response JSON Decode Error: {e}, Response: {raw_response_text[:200]}\"\r\n            break  # Exit retry loop on JSON decode errors\r\n        except Exception as e:\r\n            last_error = f\"{log_prefix}Unexpected Error: {e}\"\r\n            break  # Exit retry loop on other unexpected errors\r\n\r\n    # If loop finishes without returning success, return the last recorded error\r\n    logger.error(f\"{log_prefix}Search API call failed. 
Last error: {last_error}\")\r\n    return None, last_error.replace(log_prefix, \"API Call Failed: \") if last_error else \"API Call Failed after retries\"\r\n\r\n\r\ndef _passages2string(retrieval_result):\r\n    \"\"\"Convert retrieval results to formatted string.\"\"\"\r\n    format_reference = \"\"\r\n    for idx, doc_item in enumerate(retrieval_result):\r\n        content = doc_item[\"document\"][\"contents\"]\r\n        title = content.split(\"\\n\")[0]\r\n        text = \"\\n\".join(content.split(\"\\n\")[1:])\r\n        format_reference += f\"Doc {idx + 1} (Title: {title})\\n{text}\\n\\n\"\r\n    return format_reference.strip()\r\n\r\n\r\ndef perform_single_search_batch(\r\n    retrieval_service_url: str,\r\n    query_list: List[str],\r\n    topk: int = 3,\r\n    concurrent_semaphore: Optional[threading.Semaphore] = None,\r\n    timeout: int = DEFAULT_TIMEOUT,\r\n) -> Tuple[str, Dict[str, Any]]:\r\n    \"\"\"\r\n    Performs a single batch search for multiple queries (original search tool behavior).\r\n\r\n    Args:\r\n        retrieval_service_url: The URL of the retrieval service API.\r\n        query_list: List of search queries.\r\n        topk: Number of top results to return.\r\n        concurrent_semaphore: Optional semaphore for concurrency control.\r\n        timeout: Request timeout in seconds.\r\n\r\n    Returns:\r\n        A tuple (result_text, metadata).\r\n        result_text: The search result JSON string.\r\n        metadata: Metadata dictionary for the batch search.\r\n    \"\"\"\r\n    logger.info(f\"Starting batch search for {len(query_list)} queries.\")\r\n\r\n    api_response = None\r\n    error_msg = None\r\n\r\n    try:\r\n        if concurrent_semaphore:\r\n            with concurrent_semaphore:\r\n                api_response, error_msg = call_search_api(\r\n                    retrieval_service_url=retrieval_service_url,\r\n                    query_list=query_list,\r\n                    topk=topk,\r\n                    return_scores=True,\r\n                    timeout=timeout,\r\n                )\r\n        else:\r\n            api_response, error_msg = call_search_api(\r\n                retrieval_service_url=retrieval_service_url,\r\n                query_list=query_list,\r\n                topk=topk,\r\n                return_scores=True,\r\n                timeout=timeout,\r\n            )\r\n    except Exception as e:\r\n        error_msg = f\"API Request Exception during batch search: {e}\"\r\n        logger.error(f\"Batch search: {error_msg}\")\r\n        traceback.print_exc()\r\n\r\n    metadata = {\r\n        \"query_count\": len(query_list),\r\n        \"queries\": query_list,\r\n        \"api_request_error\": error_msg,\r\n        \"api_response\": None,\r\n        \"status\": \"unknown\",\r\n        \"total_results\": 0,\r\n        \"formatted_result\": None,\r\n    }\r\n\r\n    result_text = json.dumps({\"result\": \"Search request failed or timed out after retries.\"})\r\n\r\n    if error_msg:\r\n        metadata[\"status\"] = \"api_error\"\r\n        result_text = json.dumps({\"result\": f\"Search error: {error_msg}\"})\r\n        logger.error(f\"Batch search: API error occurred: {error_msg}\")\r\n    elif api_response:\r\n        logger.debug(f\"Batch search: API Response: {api_response}\")\r\n        metadata[\"api_response\"] = api_response\r\n\r\n        try:\r\n            raw_results = api_response.get(\"result\", [])\r\n            if raw_results:\r\n                pretty_results = []\r\n                total_results = 0\r\n\r\n  
              for retrieval in raw_results:\r\n                    formatted = _passages2string(retrieval)\r\n                    pretty_results.append(formatted)\r\n                    total_results += len(retrieval) if isinstance(retrieval, list) else 1\r\n\r\n                final_result = \"\\n---\\n\".join(pretty_results)\r\n                result_text = json.dumps({\"result\": final_result})\r\n                metadata[\"status\"] = \"success\"\r\n                metadata[\"total_results\"] = total_results\r\n                metadata[\"formatted_result\"] = final_result\r\n                logger.info(f\"Batch search: Successful, got {total_results} total results\")\r\n            else:\r\n                result_text = json.dumps({\"result\": \"No search results found.\"})\r\n                metadata[\"status\"] = \"no_results\"\r\n                metadata[\"total_results\"] = 0\r\n                logger.info(\"Batch search: No results found\")\r\n        except Exception as e:\r\n            error_msg = f\"Error processing search results: {e}\"\r\n            result_text = json.dumps({\"result\": error_msg})\r\n            metadata[\"status\"] = \"processing_error\"\r\n            logger.error(f\"Batch search: {error_msg}\")\r\n    else:\r\n        metadata[\"status\"] = \"unknown_api_state\"\r\n        result_text = json.dumps({\"result\": \"Unknown API state (no response and no error message).\"})\r\n        logger.error(\"Batch search: Unknown API state.\")\r\n\r\n    return result_text, metadata\r\n"
  },
  {
    "path": "siirl/execution/rollout_flow/multiturn/tools/utils/tool_registry.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport asyncio\nimport importlib\nimport logging\nimport os\nimport sys\nfrom enum import Enum\n\nfrom omegaconf import OmegaConf\n\nfrom siirl.execution.rollout_flow.multiturn.tools.schemas import OpenAIFunctionToolSchema\n\nfrom loguru import logger\n\n\nclass ToolType(Enum):\n    NATIVE = \"native\"\n    MCP = \"mcp\"\n\n\nasync def initialize_mcp_tool(tool_cls, tool_config) -> list:\n    from siirl.execution.rollout_flow.multiturn.tools.utils.mcp_clients.McpClientManager import ClientManager\n\n    tool_list = []\n    mcp_servers_config_path = tool_config.mcp.mcp_servers_config_path\n    tool_selected_list = tool_config.mcp.tool_selected_list if \"tool_selected_list\" in tool_config.mcp else None\n    await ClientManager.initialize(mcp_servers_config_path, tool_config.config.rate_limit)\n    # Wait for MCP client to be ready\n    max_retries = 10\n    retry_interval = 2  # seconds\n    for i in range(max_retries):\n        tool_schemas = await ClientManager.fetch_tool_schemas(tool_selected_list)\n        if tool_schemas:\n            break\n        if i < max_retries - 1:\n            logger.debug(f\"Waiting for MCP client to be ready, attempt {i + 1}/{max_retries}\")\n            await asyncio.sleep(retry_interval)\n    else:\n        raise RuntimeError(\"Failed to initialize MCP tools after maximum retries\")\n    # mcp registry\n    assert len(tool_schemas), \"mcp tool is empty\"\n    for tool_schema_dict in tool_schemas:\n        logger.debug(f\"tool_schema_dict: {tool_schema_dict}\")\n        tool_schema = OpenAIFunctionToolSchema.model_validate(tool_schema_dict)\n        tool = tool_cls(\n            config=OmegaConf.to_container(tool_config.config, resolve=True),\n            tool_schema=tool_schema,\n        )\n        tool_list.append(tool)\n    return tool_list\n\n\ndef get_tool_class(cls_name):\n    module_name, class_name = cls_name.rsplit(\".\", 1)\n    if module_name not in sys.modules:\n        spec = importlib.util.find_spec(module_name)\n        module = importlib.util.module_from_spec(spec)\n        sys.modules[module_name] = module\n        spec.loader.exec_module(module)\n    else:\n        module = sys.modules[module_name]\n\n    tool_cls = getattr(module, class_name)\n    return tool_cls\n\n\ndef initialize_tools_from_config(tools_config_file):\n    tools_config = OmegaConf.load(tools_config_file)\n    tool_list = []\n    for tool_config in tools_config.tools:\n        cls_name = tool_config.class_name\n        tool_type = ToolType(tool_config.config.type)\n        tool_cls = get_tool_class(cls_name)\n\n        match tool_type:\n            case ToolType.NATIVE:\n                if tool_config.get(\"tool_schema\", None) is None:\n                    tool_schema = None\n                else:\n                    tool_schema_dict = OmegaConf.to_container(tool_config.tool_schema, resolve=True)\n                    tool_schema = 
OpenAIFunctionToolSchema.model_validate(tool_schema_dict)\n                tool = tool_cls(\n                    config=OmegaConf.to_container(tool_config.config, resolve=True),\n                    tool_schema=tool_schema,\n                )\n                tool_list.append(tool)\n            case ToolType.MCP:\n                loop = asyncio.get_event_loop()\n                mcp_tools = loop.run_until_complete(initialize_mcp_tool(tool_cls, tool_config))\n                tool_list.extend(mcp_tools)\n            case _:\n                raise NotImplementedError\n    return tool_list\n"
  },
  {
    "path": "siirl/execution/scheduler/__init__.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n"
  },
  {
    "path": "siirl/execution/scheduler/enums.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom enum import Enum\n\n\nclass AdvantageEstimator(str, Enum):\n    \"\"\"\n    Using an enumeration class to avoid spelling errors in adv_estimator\n    \"\"\"\n\n    GAE = \"gae\"\n    GRPO = \"grpo\"\n    CPGD = \"cpgd\"\n    GAE_MARFT = \"gae_marft\"\n\n\nclass WorkflowType(str, Enum):\n    DEFAULT = \"default\"\n    DAPO = \"dapo\"\n    EMBODIED = \"embodied\"\n\n\nclass Role(Enum):\n    \"\"\"\n    To create more roles dynamically, you can subclass Role and add new members\n    \"\"\"\n\n    Actor = 0\n    Rollout = 1\n    ActorRollout = 2\n    Critic = 3\n    RefPolicy = 4\n    RewardModel = 5\n    ActorRolloutRef = 6\n\n\nclass AlgorithmType(Enum):\n    \"\"\"\n    Enum to represent different algorithm types.\n    \"\"\"\n\n    PPO = \"ppo\"\n    GRPO = \"grpo\"\n    DAPO = \"dapo\"\n    REINFORCE_PLUS_PLUS = \"reinforce_plus_plus\"\n    REMAX = \"remax\"\n    RLOO = \"rloo\"\n    OPO = \"opo\"\n    GRPO_PASSK = \"grpo_passk\"\n    CPGD = \"cpgd\"\n    REINFORCE_PLUS_PLUS_BASELINE = \"reinforce_plus_plus_baseline\"\n"
  },
  {
    "path": "siirl/execution/scheduler/graph_updater.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Dict, Tuple, Type, Any, Optional\nfrom dataclasses import asdict, is_dataclass\n\nfrom dacite import Config as DaciteConfig, from_dict\nfrom loguru import logger\nfrom omegaconf import DictConfig, OmegaConf\n\nfrom siirl.execution.dag import NodeRole, NodeType, TaskGraph\nfrom siirl.params import ActorRolloutRefArguments, CriticArguments, RewardModelArguments, SiiRLArguments\n\nfrom typing import TYPE_CHECKING\n\nif TYPE_CHECKING:\n    from siirl.execution.dag import TaskGraph\nfrom siirl.params import log_dict_formatted\n\nNODE_ID = \"_node_id_\"\nINTERN_CONFIG = \"intern_config\"\n\n\ndef unflatten_dict_with_omegaconf(flat_dict: Dict[str, Any]) -> Dict[str, Any]:\n    \"\"\"\n    Unflattens a flat dictionary with dot-separated keys into a nested dictionary using OmegaConf.\n\n    Args:\n        flat_dict: A dictionary where keys might be dot-separated (e.g., 'model.name').\n\n    Returns:\n        A nested dictionary.\n    \"\"\"\n    if not flat_dict:\n        return {}\n    config = OmegaConf.create()\n    for key, value in flat_dict.items():\n        try:\n            OmegaConf.update(config, key, value, merge=True, force_add=True)\n        except Exception as e:\n            logger.error(f\"OmegaConf.update failed for key='{key}', value='{value}': {e}\")\n            raise\n    return OmegaConf.to_container(config, resolve=True, throw_on_missing=False)\n\n\ndef update_task_graph_node_configs(workerflow_taskgraph: TaskGraph, basic_common_config: \"SiiRLArguments\") -> TaskGraph:\n    \"\"\"\n    Updates node configurations by merging global defaults with node-specific overrides,\n    and stores the resulting configuration as both a dictionary and a dataclass instance.\n\n    Args:\n        workerflow_taskgraph: The TaskGraph whose nodes will be updated.\n        basic_common_config: The global SiiRLArguments with default settings.\n\n    Returns:\n        The updated TaskGraph.\n    \"\"\"\n    logger.info(\"Starting update of TaskGraph node configurations (using OmegaConf and Dacite)...\")\n    workerflow_taskgraph.build_adjacency_lists()\n\n    node_role_config_map: Dict[NodeRole, Tuple[str, Type]] = {\n        NodeRole.ACTOR: (\"actor_rollout_ref\", ActorRolloutRefArguments),\n        NodeRole.ROLLOUT: (\"actor_rollout_ref\", ActorRolloutRefArguments),\n        NodeRole.REFERENCE: (\"actor_rollout_ref\", ActorRolloutRefArguments),\n        NodeRole.CRITIC: (\"critic\", CriticArguments),\n        NodeRole.REWARD: (\"reward_model\", RewardModelArguments),\n    }\n\n    for node_id, node in workerflow_taskgraph.nodes.items():\n        if node.node_type not in [NodeType.MODEL_INFERENCE, NodeType.MODEL_TRAIN]:\n            logger.debug(f\"Node '{node.node_id}' of type {node.node_type} skipped for config update.\")\n            continue\n\n        original_node_config_flat = node.config or {}\n        original_node_config_dict = 
unflatten_dict_with_omegaconf(original_node_config_flat)\n\n        if NODE_ID in original_node_config_dict:\n            del original_node_config_dict[NODE_ID]\n        node_specific_omega_conf = OmegaConf.create(original_node_config_dict)\n\n        if node.node_role in node_role_config_map:\n            default_config_attr_name, target_dataclass_type = node_role_config_map[node.node_role]\n            default_config_branch_instance = getattr(basic_common_config, default_config_attr_name, None)\n\n            merged_omega_conf: Optional[DictConfig] = None\n\n            if default_config_branch_instance is None:\n                logger.warning(f\"Global default config '{default_config_attr_name}' not in basic_common_config for node '{node.node_id}'. Using only node-specific config.\")\n                merged_omega_conf = node_specific_omega_conf\n            else:\n                default_config_branch_dict = asdict(default_config_branch_instance)\n                default_config_branch_omega_base = OmegaConf.create(default_config_branch_dict)\n\n                if not isinstance(default_config_branch_omega_base, DictConfig):\n                    logger.error(f\"Global config for '{default_config_attr_name}' is not a DictConfig. Cannot merge. Using only node-specific config for node '{node.node_id}'.\")\n                    merged_omega_conf = node_specific_omega_conf\n                else:\n                    merged_omega_conf = OmegaConf.merge(default_config_branch_omega_base.copy(), node_specific_omega_conf)\n\n            merged_config_dict = OmegaConf.to_container(merged_omega_conf, resolve=True, throw_on_missing=False)\n            if not isinstance(merged_config_dict, dict):\n                raise ValueError(f\"Merged config for node '{node.node_id}' is not a dictionary.\")\n\n            try:\n                # Convert the merged dictionary back into a validated dataclass instance\n                merged_dataclass_instance = from_dict(data_class=target_dataclass_type, data=merged_config_dict, config=DaciteConfig(check_types=False))\n            except Exception as e:\n                logger.error(f\"Dacite conversion to '{target_dataclass_type.__name__}' failed for node '{node.node_id}': {e}\")\n                raise\n            node.config = {INTERN_CONFIG: merged_dataclass_instance, NODE_ID: node.node_id}\n\n        else:\n            logger.warning(f\"Node '{node.node_id}' ({node.node_role}) has an unmapped role. Using its unflattened original configuration without creating a dataclass instance.\")\n            node.config = original_node_config_dict\n\n    logger.info(\"TaskGraph node configuration update complete.\")\n    return workerflow_taskgraph\n\n\ndef display_node_config(workerflow_taskgraph: TaskGraph) -> None:\n    \"\"\"\n    Prints the configuration for each node.\n    This version is adapted for when node.config primarily holds a dataclass instance.\n    \"\"\"\n    if not isinstance(workerflow_taskgraph, TaskGraph):\n        logger.error(\"Error: Input must be a TaskGraph object.\")\n        return\n\n    if not workerflow_taskgraph.nodes:\n        logger.warning(f\"Graph '{workerflow_taskgraph.graph_id}' has no nodes.\")\n        return\n\n    logger.debug(f\"Displaying configurations for all nodes in graph '{workerflow_taskgraph.graph_id}':\")\n\n    for node_id, node in workerflow_taskgraph.nodes.items():\n        if not isinstance(node.config, dict):\n            logger.warning(f\"Node '{node_id}' config is not a dictionary. 
Skipping.\")\n            continue\n\n        dataclass_obj = node.config.get(INTERN_CONFIG)\n\n        if dataclass_obj and is_dataclass(dataclass_obj):\n            config_for_display = asdict(dataclass_obj)\n\n            # Include the node ID in the displayed configuration\n            if NODE_ID in node.config:\n                config_for_display[NODE_ID] = node.config[NODE_ID]\n\n            log_dict_formatted(config_for_display, title=f\"Node: {node_id} Configuration Details\", log_level=\"debug\")\n        else:\n            # If the config is not a dataclass, log the raw dictionary\n            logger.warning(f\"Node '{node_id}' does not contain a valid dataclass in '{INTERN_CONFIG}'.\")\n            log_dict_formatted(node.config, title=f\"Node: {node_id} Raw Configuration Details\", log_level=\"debug\")\n"
  },
  {
    "path": "siirl/execution/scheduler/launch.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport time\nfrom typing import Any, Dict, List\nimport ray\nfrom loguru import logger\n\nfrom siirl.execution.dag import Node, NodeRole, NodeType, TaskGraph\n\nfrom siirl.params import ActorRolloutRefArguments, ActorArguments, RefArguments, RolloutArguments, CriticArguments, RewardModelArguments, SiiRLArguments\nfrom siirl.execution.scheduler.enums import AdvantageEstimator\n\nfrom .process_group_manager import ProcessGroupManager\nfrom .ray_actor_manager import RayActorManager\nfrom .graph_updater import INTERN_CONFIG\nfrom .resource_manager import ResourcePoolManager\n\n\nclass RayTrainer:\n    \"\"\"\n    The main orchestrator for a distributed training session using Ray.\n\n    This class is responsible for:\n    1.  Validating the configurations for all components in the task graphs.\n    2.  Managing hardware resources (GPUs) across nodes.\n    3.  Initializing and managing the lifecycle of Ray actors (DAGWorkers).\n    4.  Starting the training process and monitoring its execution until completion or failure.\n    \"\"\"\n\n    def __init__(self, config: SiiRLArguments, process_group_manager: ProcessGroupManager, rank_taskgraph_mapping: Dict[int, \"TaskGraph\"], unique_graphs_map: Dict[str, \"TaskGraph\"], data_coordinator_handle: \"ray.actor.ActorHandle\", metric_worker_handle: \"ray.actor.ActorHandle\", device_name=\"cuda\"):\n        \"\"\"\n        Initializes the RayTrainer.\n\n        Args:\n            config: The main SiiRLArguments object containing all configuration parameters.\n            process_group_manager: Manages communication groups for distributed training.\n            rank_taskgraph_mapping: A mapping from a global rank to its assigned TaskGraph.\n            unique_graphs_map: A mapping of unique graph IDs to their TaskGraph objects.\n            data_coordinator_handle: The Ray actor handle for the central DataCoordinator.\n            metric_worker_handle: The Ray actor handle for the Central Metric Worker.\n        \"\"\"\n        # Store essential configuration and management objects.\n        self.base_config = config\n        self.process_group_manager = process_group_manager\n        self.rank_taskgraph_mapping = rank_taskgraph_mapping\n        self.data_coordinator_handle = data_coordinator_handle\n        self.metric_worker_handle = metric_worker_handle\n        self.unique_graphs_map = unique_graphs_map\n\n        # Calculate the total number of GPUs available for the training job.\n        self.total_gpu = self.base_config.trainer.n_gpus_per_node * self.base_config.trainer.nnodes\n\n        # Determine whether a critic model is needed based on the chosen advantage estimator algorithm.\n        # GAE requires a critic for value estimation. 
Other listed methods do not.\n        if self.base_config.algorithm.adv_estimator == AdvantageEstimator.GAE:\n            self.use_critic = True\n        elif self.base_config.algorithm.adv_estimator in [\n            AdvantageEstimator.GRPO,\n            AdvantageEstimator.CPGD,\n            AdvantageEstimator.GAE_MARFT,\n        ]:\n            self.use_critic = False\n        else:\n            # If the algorithm is not recognized, raise an error.\n            raise NotImplementedError\n\n        # --- Create resource manager ---\n        # Define the specification for the global resource pool, typically GPUs per node.\n        self.global_pool_id = \"global_resource_pool\"\n        resource_pool_spec = {\n            self.global_pool_id: [self.base_config.trainer.n_gpus_per_node] * self.base_config.trainer.nnodes,\n        }\n        # Instantiate the manager to oversee the allocation of these resources.\n        self.resource_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec)\n        # The actor manager is initialized later in the `init_workers` method.\n        self.ray_actor_manager = None\n        self.device_name = device_name\n\n        # Perform a comprehensive validation of all configurations upon initialization.\n        self._validate_config()\n\n    def _check_mutually_exclusive(self, component_config: Dict[str, Any], param_name: str, param_per_gpu_name: str, component_id_str: str) -> None:\n        \"\"\"\n        A helper function to validate that only one of two mutually exclusive batch size parameters is set.\n\n        It enforces that users set the 'per_gpu' version of a parameter and not the deprecated total version.\n\n        Args:\n            component_config: The configuration dictionary for a specific component.\n            param_name: The name of the deprecated total batch size parameter (e.g., \"ppo_micro_batch_size\").\n            param_per_gpu_name: The name of the preferred per-GPU batch size parameter (e.g., \"ppo_micro_batch_size_per_gpu\").\n            component_id_str: A string identifying the component for clear error messages (e.g., \"Node 1 (Actor)\").\n        \"\"\"\n        # Get both parameter values from the config dict.\n        mbs = component_config.get(param_name, None)\n        mbs_per_gpu = component_config.get(param_per_gpu_name, None)\n\n        # Fail if neither parameter is provided.\n        if mbs is None and mbs_per_gpu is None:\n            raise ValueError(f\"[{component_id_str}] Please set at least one of '{param_name}' or '{param_per_gpu_name}'.\")\n\n        # Fail if both are provided, guiding the user to use the preferred 'per_gpu' parameter.\n        if mbs is not None and mbs_per_gpu is not None:\n            raise ValueError(f\"[{component_id_str}] You have set both '{param_name}' AND '{param_per_gpu_name}'. 
Please remove '{param_name}' because only '{param_per_gpu_name}' is supported (the former is deprecated).\")\n\n    def validate_actor_config(self, node: Node, actor_conf: ActorArguments, use_remove_padding: bool = False) -> None:\n        \"\"\"Validates configuration parameters specific to an Actor training node.\"\"\"\n        logger.debug(f\"Validating Actor specific configurations for Node: {node.node_id} using provided actor_conf\")\n\n        # Check if KL is used both as a reward penalty and a loss term, which is a valid but notable configuration.\n        if self.base_config.algorithm.use_kl_in_reward and actor_conf.use_kl_loss:\n            logger.info(f\"Node {node.node_id} (Actor): Both in-reward KL and KL loss are enabled for this actor configuration.\")\n\n        # Extract relevant actor parameters for validation.\n        ppo_mini_batch_size = actor_conf.ppo_mini_batch_size\n        ppo_micro_batch_size_per_gpu = actor_conf.ppo_micro_batch_size_per_gpu\n        ppo_micro_batch_size = actor_conf.ppo_micro_batch_size\n        sp_size = actor_conf.ulysses_sequence_parallel_size\n        loss_agg_mode = actor_conf.loss_agg_mode\n        strategy = actor_conf.strategy\n        use_dynamic_bsz = actor_conf.use_dynamic_bsz\n\n        # Perform batch size validations if not using a dynamic batch size.\n        if not use_dynamic_bsz:\n            self._check_mutually_exclusive(actor_conf.to_dict(), \"ppo_micro_batch_size\", \"ppo_micro_batch_size_per_gpu\", f\"Node {node.node_id} (Actor)\")\n            # Ensure the global training batch size is at least as large as the PPO mini-batch size.\n            assert self.base_config.data.train_batch_size * self.base_config.actor_rollout_ref.rollout.n >= ppo_mini_batch_size, f\"Node {node.node_id} (Actor): train_batch_size ({self.base_config.data.train_batch_size}) must be >= ppo_mini_batch_size ({ppo_mini_batch_size})\"\n            # If micro-batch size is set, perform further divisibility checks.\n            if ppo_micro_batch_size is not None:\n                component_total_mbs = ppo_micro_batch_size_per_gpu * self.total_gpu\n                assert ppo_mini_batch_size % ppo_micro_batch_size == 0, f\"Node {node.node_id} (Actor): ppo_mini_batch_size ({ppo_mini_batch_size}) must be divisible by component_total_mbs ({component_total_mbs}).\"\n                # This assertion seems to have a typo, but keeping logic as-is. 
It compares total vs per_gpu * sp.\n                assert ppo_micro_batch_size * sp_size >= self.total_gpu, f\"Node {node.node_id} (Actor): ppo_micro_batch_size_per_gpu * SP size ({ppo_micro_batch_size_per_gpu} * {sp_size}) must be >= {self.total_gpu}\"\n\n        # Ensure the loss aggregation mode is one of the supported values.\n        assert loss_agg_mode in [\"token-mean\", \"seq-mean-token-sum\", \"seq-mean-token-mean\", \"seq-mean-token-sum-norm\"], f\"Node {node.node_id} (Actor): Invalid loss_agg_mode: {loss_agg_mode}\"\n\n        # For FSDP with sequence parallelism, padding must be removed to avoid hangs.\n        if strategy == \"fsdp\" and sp_size > 1:\n            assert use_remove_padding, f\"Node {node.node_id} (Actor): When using SP (>1) with FSDP, enable `use_remove_padding` in the relevant model config.\"\n\n    def validate_reference_config(self, node: Node, reference_conf: RefArguments, use_remove_padding: bool = False) -> None:\n        \"\"\"Validates configuration parameters specific to a Reference Policy inference node.\"\"\"\n        logger.debug(f\"Validating Reference Policy specific configurations for Node: {node.node_id}\")\n        log_prob_use_dynamic_bsz = reference_conf.log_prob_use_dynamic_bsz\n        strategy = reference_conf.strategy\n        ulysses_sequence_parallel_size = reference_conf.ulysses_sequence_parallel_size\n\n        # Validate micro batch size settings if not using dynamic batching.\n        if not log_prob_use_dynamic_bsz:\n            self._check_mutually_exclusive(reference_conf.to_dict(), \"log_prob_micro_batch_size\", \"log_prob_micro_batch_size_per_gpu\", f\"Node {node.node_id} (Reference)\")\n\n        # For FSDP with sequence parallelism, padding must be removed.\n        if strategy == \"fsdp\" and ulysses_sequence_parallel_size > 1:\n            assert use_remove_padding, f\"Node {node.node_id} (Reference): When using SP (>1) with FSDP, enable `use_remove_padding` in relevant model config.\"\n\n    def validate_rollout_config(self, node: Node, rollout_conf: RolloutArguments, use_remove_padding: bool = False):\n        \"\"\"Validates configuration parameters specific to a Rollout (generation) node.\"\"\"\n        logger.debug(f\"Validating Rollout specific configurations for Node: {node.node_id}\")\n\n        # Validate micro batch size for log-probability calculations if not dynamic.\n        log_prob_use_dynamic_bsz = rollout_conf.log_prob_use_dynamic_bsz\n        if not log_prob_use_dynamic_bsz:\n            self._check_mutually_exclusive(rollout_conf.to_dict(), \"log_prob_micro_batch_size\", \"log_prob_micro_batch_size_per_gpu\", f\"Node {node.node_id} (Rollout)\")\n\n        # Extract generation parameters for validation.\n        do_sample = rollout_conf.val_kwargs.do_sample\n        multi_turn_enable = rollout_conf.multi_turn.enable\n        temperature = rollout_conf.temperature\n        tool_config_path = rollout_conf.multi_turn.tool_config_path\n\n        # If sampling is enabled, temperature must be positive.\n        if do_sample:\n            assert temperature > 0, f\"Node {node.node_id} (Rollout): validation gen temperature > 0 for do_sample.\"\n\n        # If multi-turn rollouts (e.g., with tools) are enabled, a tool config must be provided.\n        if multi_turn_enable:\n            assert tool_config_path is not None, f\"Node {node.node_id} (Rollout): tool_config_path required for multi_turn.\"\n            # Check if the algorithm is compatible with multi-turn rollouts.\n            assert 
self.base_config.algorithm.adv_estimator in [AdvantageEstimator.GRPO], f\"only GRPO is tested for multi-turn with tool\"\n\n    def validate_critic_config(self, node: Node, critic_conf: CriticArguments, use_remove_padding: bool = False):\n        \"\"\"Validates configuration parameters specific to a Critic training node.\"\"\"\n        logger.debug(f\"Validating Critic specific configurations for Node: {node.node_id}\")\n\n        # Extract critic parameters.\n        use_dynamic_bsz = critic_conf.use_dynamic_bsz\n        ppo_mini_batch_size = critic_conf.ppo_mini_batch_size\n        ppo_micro_batch_size_per_gpu = critic_conf.ppo_micro_batch_size_per_gpu\n        ppo_micro_batch_size = critic_conf.ppo_micro_batch_size\n        sp_size = critic_conf.ulysses_sequence_parallel_size\n        strategy = critic_conf.strategy\n\n        # Validate batch sizes if not dynamic.\n        if not use_dynamic_bsz:\n            self._check_mutually_exclusive(critic_conf.to_dict(), \"ppo_micro_batch_size\", \"ppo_micro_batch_size_per_gpu\", f\"Node {node.node_id} (Critic)\")\n            assert self.base_config.data.train_batch_size >= ppo_mini_batch_size\n            if ppo_micro_batch_size is not None:\n                effective_mbs_per_gpu = ppo_micro_batch_size_per_gpu\n                assert ppo_mini_batch_size % ppo_micro_batch_size == 0\n                assert ppo_micro_batch_size * sp_size >= 1\n\n        # For FSDP with sequence parallelism, padding must be removed.\n        if strategy == \"fsdp\" and sp_size > 1:\n            assert use_remove_padding, f\"Node {node.node_id} (Critic): When using SP (>1) with FSDP, enable `use_remove_padding` in critic.model.\"\n\n    def validate_reward_model_config(self, node: Node, reward_model_conf: RewardModelArguments, use_remove_padding: bool = False):\n        \"\"\"Validates configuration parameters specific to a Reward Model training node.\"\"\"\n        logger.debug(f\"Validating Reward Model specific configurations for Node: {node.node_id}\")\n        use_dynamic_bsz = reward_model_conf.use_dynamic_bsz\n        # Validate micro batch size settings if not using dynamic batching.\n        if not use_dynamic_bsz:\n            self._check_mutually_exclusive(reward_model_conf.to_dict(), \"micro_batch_size\", \"micro_batch_size_per_gpu\", f\"Node {node.node_id} (RewardModel)\")\n\n    def validate_configurations_for_task_graph(self, task_graph: TaskGraph) -> None:\n        \"\"\"\n        Iterates through all nodes in a task graph and dispatches to the appropriate validation function.\n\n        Args:\n            task_graph: The TaskGraph object to validate.\n        \"\"\"\n        logger.info(f\"Starting configuration validation for TaskGraph: {task_graph.graph_id}\")\n\n        # Loop over each node in the graph.\n        for node_id, node in task_graph.nodes.items():\n            logger.debug(f\"Processing Node ID: {node.node_id}, Type: {node.node_type.value}, Role: {node.node_role.value}\")\n\n            # Initialize variables for the dispatcher.\n            node_specific_config: Any = None\n            validator_function = None\n            component_name_for_logging = \"\"\n            use_remove_padding = False\n            intern_config = None\n            # The actual component-specific config is stored in a special key.\n            if INTERN_CONFIG in node.config:\n                intern_config = node.config[INTERN_CONFIG]\n\n            # Based on the node's type and role, select the correct config object and validator function.\n          
  if node.node_type == NodeType.MODEL_TRAIN and node.node_role == NodeRole.ACTOR:\n                assert isinstance(intern_config, ActorRolloutRefArguments), f\"Node {node_id} intern config illegal\"\n                # Calculate the effective batch size considering the number of rollouts per sample.\n                real_train_batch_size = self.base_config.data.train_batch_size * intern_config.rollout.n\n                assert real_train_batch_size % self.total_gpu == 0, f\"real_train_batch_size ({real_train_batch_size}) must be divisible by total n_gpus ({self.total_gpu}).\"\n                node_specific_config = intern_config.actor\n                validator_function = self.validate_actor_config\n                use_remove_padding = intern_config.model.use_remove_padding\n                component_name_for_logging = \"Actor\"\n            elif node.node_type == NodeType.MODEL_INFERENCE and node.node_role == NodeRole.ROLLOUT:\n                assert isinstance(intern_config, ActorRolloutRefArguments), f\"Node {node_id} intern config illegal\"\n                node_specific_config = intern_config.rollout\n                validator_function = self.validate_rollout_config\n                component_name_for_logging = \"Rollout\"\n            elif node.node_type == NodeType.MODEL_TRAIN and node.node_role == NodeRole.CRITIC:\n                assert isinstance(intern_config, CriticArguments), f\"Node {node_id} intern config illegal\"\n                if self.base_config.critic:\n                    node_specific_config = intern_config\n                    validator_function = self.validate_critic_config\n                    component_name_for_logging = \"Critic\"\n                    use_remove_padding = intern_config.model.use_remove_padding\n            elif node.node_type == NodeType.MODEL_TRAIN and node.node_role == NodeRole.REWARD:\n                assert isinstance(intern_config, RewardModelArguments), f\"Node {node_id} intern config illegal\"\n                if self.base_config.reward_model.enable:\n                    node_specific_config = intern_config\n                    validator_function = self.validate_reward_model_config\n                    component_name_for_logging = \"RewardModel\"\n                    use_remove_padding = intern_config.model.use_remove_padding\n            elif node.node_type == NodeType.MODEL_TRAIN and node.node_role == NodeRole.REFERENCE:\n                assert isinstance(intern_config, ActorRolloutRefArguments), f\"Node {node_id} intern config illegal\"\n                node_specific_config = intern_config.ref\n                validator_function = self.validate_reference_config\n                component_name_for_logging = \"ReferencePolicy\"\n                use_remove_padding = intern_config.ref.use_remove_padding\n\n            # If a validator was found for the node, execute it.\n            if validator_function and node_specific_config is not None:\n                try:\n                    logger.debug(f\"Running {component_name_for_logging} validation for Node: {node.node_id}\")\n                    # Pass the node, its specific config, and padding flag to the validator.\n                    validator_function(node, node_specific_config, use_remove_padding)\n                except (AssertionError, ValueError) as e:\n                    # If validation fails, log a fatal error and re-raise to halt execution.\n                    logger.error(f\"Configuration validation FAILED for Node {node.node_id} ({component_name_for_logging}): {e}\")\n     
               raise\n            elif validator_function and node_specific_config is None:\n                # This case handles when a node should be validated (e.g., Critic) but is disabled in the main config.\n                logger.warning(f\"Node {node.node_id} ({node.node_type.value}, {node.node_role.value}) mapped to {component_name_for_logging} validator, but its config section was not found or component not enabled. Skipping specialized validation.\")\n            else:\n                # For nodes that do not require specialized validation (e.g., data nodes).\n                logger.trace(f\"No specialized validator or config section for Node {node.node_id} ({node.node_type.value}, {node.node_role.value}).\")\n\n        logger.info(f\"All configuration checks passed successfully for TaskGraph: {task_graph.graph_id}!\")\n\n    def _validate_config(self):\n        \"\"\"Entry point for configuration validation. Validates all unique task graphs.\"\"\"\n        for graph_id, task_graph in self.unique_graphs_map.items():\n            self.validate_configurations_for_task_graph(task_graph)\n\n    def init_workers(self):\n        \"\"\"Initializes the resources and Ray actors required for training.\"\"\"\n        # Step 1: Create the resource pool based on the spec defined in __init__.\n        self.resource_manager.create_resource_pool()\n        # Step 2: Create the RayActorManager, which will be responsible for creating and managing the actual DAGWorker actors.\n        ray_actor_manager_kwargs = {\"ray_wait_register_center_timeout\": self.base_config.trainer.ray_wait_register_center_timeout}\n\n        self.ray_actor_manager = RayActorManager(\n            resource_pool=self.resource_manager.get_resource_pool(self.global_pool_id),\n            base_config=self.base_config,\n            process_manager=self.process_group_manager,\n            rank_taskgraph_mapping=self.rank_taskgraph_mapping,\n            data_coordinator_handle=self.data_coordinator_handle,\n            metric_worker_handle=self.metric_worker_handle,\n            device_name=self.device_name,\n            **ray_actor_manager_kwargs,\n        )\n\n    def start_workers(self):\n        \"\"\"\n        Starts all DAGWorkers and enters a monitoring loop that waits for them to complete.\n        This method handles progress logging and robustly detects and reports actor failures.\n        \"\"\"\n        logger.success(\"create workers finished, try start training\")\n        # 1. Asynchronously start the main task (`execute_task_graph`) on all workers.\n        # This returns a list of \"futures\", which are placeholders for the eventual results.\n        work_futures = self.ray_actor_manager.map_async(method_name=\"execute_task_graph\")\n\n        start_time = time.time()\n        num_workers = len(self.ray_actor_manager.workers)\n\n        # Create a mapping from a future to its worker's name for easier logging upon completion or failure.\n        future_to_worker_name = {future: name for future, name in zip(work_futures, self.ray_actor_manager.worker_names)}\n\n        # Create a copy of the futures list to track which workers are still running.\n        remaining_futures = work_futures.copy()\n\n        # Loop until all workers have completed their tasks.\n        while remaining_futures:\n            try:\n                # 2. 
Wait for ANY of the remaining tasks to complete.\n                # Use a timeout so the loop doesn't block indefinitely, allowing for periodic logging.\n                ready_futures, remaining_futures = ray.wait(\n                    remaining_futures,\n                    num_returns=1,  # Return as soon as one worker finishes.\n                    timeout=60.0,  # Wait for up to 60 seconds.\n                )\n\n                # If the wait timed out and no futures are ready, it means workers are still running.\n                # Log progress and continue to the next iteration of the while loop.\n                if not ready_futures:\n                    elapsed_time = time.time() - start_time\n                    finished_count = num_workers - len(remaining_futures)\n                    logger.info(f\"INFO: Training for {elapsed_time:.0f} seconds... {finished_count}/{num_workers} workers have finished.\")\n                    continue\n\n                # 3. Process the futures that are now ready (i.e., workers that have finished).\n                for future in ready_futures:\n                    worker_name = future_to_worker_name[future]\n                    try:\n                        # `ray.get()` retrieves the result of the future.\n                        # CRITICAL: If the remote actor task failed with an exception,\n                        # `ray.get()` will re-raise that exception here, allowing us to catch it.\n                        result = ray.get(future)\n                        logger.success(f\"Worker {worker_name} has finished its task graph. Result: {result}\")\n\n                    # Specifically catch the case where a Ray actor process has died.\n                    except ray.exceptions.ActorDiedError:\n                        logger.error(f\"FATAL: Worker {worker_name} died unexpectedly during its task. Halting execution.\")\n                        # Re-raise the exception to stop the entire training job.\n                        # A dead worker is a critical failure that cannot be recovered from.\n                        raise\n                    # Catch any other exception that the worker might have thrown.\n                    except Exception as e:\n                        logger.error(f\"FATAL: Worker {worker_name} failed with an exception: {e}\")\n                        # Re-raise to halt the training job.\n                        raise\n\n            except Exception as e:\n                # Catch unexpected errors in the monitoring loop itself.\n                logger.error(f\"An unexpected error occurred during worker monitoring: {e}\")\n                raise\n\n        # This point is reached only if all workers complete successfully.\n        elapsed_time = time.time() - start_time\n        logger.success(f\"All {num_workers} DAGWorkers have successfully finished their task graphs. Total cost: {elapsed_time:.2f}s\")\n        return\n"
  },
  {
    "path": "siirl/execution/scheduler/process_group_manager.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport collections\nfrom typing import Any, Dict, List, Optional, Set, Tuple\n\nfrom loguru import logger\nfrom siirl.execution.dag import NodeType, TaskGraph\n\n\nclass ProcessGroupManager:\n    \"\"\"\n    Manages the creation and assignment of process groups for distributed training.\n\n    This class analyzes the topology of task graphs assigned to different workers (ranks)\n    to determine which ranks need to communicate. It then defines process groups\n    based on these communication patterns and provides methods to query these\n    configurations.\n\n    Attributes:\n        total_num_workers (int): The total number of workers.\n        ranks_taskgraph_mapping (Dict[int, Optional['TaskGraph']]): A mapping\n            from a worker's rank to its assigned TaskGraph.\n        relevant_node_types (Set[NodeType]): The set of node types to consider\n            when forming process groups.\n        process_group_spec (Dict[str, List[int]]): A mapping from a generated\n            process group name to the list of ranks it contains.\n        node_process_group_mapping (Dict[str, str]): A mapping from a node's ID\n            to the name of the process group it belongs to.\n    \"\"\"\n\n    def __init__(\n        self,\n        total_num_workers: int,\n        ranks_taskgraph_mapping: Dict[int, Optional[\"TaskGraph\"]],\n        relevant_node_types_param: Optional[Set[NodeType]] = None,\n    ):\n        \"\"\"Initializes the ProcessGroupManager.\n\n        Args:\n            total_num_workers: The total number of workers in the distributed setup.\n            ranks_taskgraph_mapping: A mapping from worker ranks to their assigned TaskGraph.\n            relevant_node_types_param: A set of NodeTypes to consider for group\n                formation. 
If None, it defaults to MODEL_INFERENCE and MODEL_TRAIN.\n\n        Raises:\n            ValueError: If total_num_workers is not positive or if\n                relevant_node_types_param has an invalid type.\n        \"\"\"\n        if total_num_workers <= 0:\n            raise ValueError(\"Total number of workers must be positive.\")\n        self.total_num_workers: int = total_num_workers\n        self.ranks_taskgraph_mapping: Dict[int, Optional[\"TaskGraph\"]] = dict(ranks_taskgraph_mapping)\n\n        # --- Internal State Mappings ---\n        # Maps a node ID to the sorted list of ranks that execute it.\n        self.node_ranks_mapping: Dict[str, List[int]] = {}\n        # Maps a process group name to its list of member ranks.\n        self.process_group_spec: Dict[str, List[int]] = {}\n        # Maps a node ID to its assigned process group name.\n        self.node_process_group_mapping: Dict[str, str] = {}\n        # Maps a node type (as a string) to the set of PGs associated with it.\n        self.node_type_process_group_mapping: Dict[str, Set[str]] = collections.defaultdict(set)\n        # Maps subgraph ID -> node type -> set of PG names.\n        self.subgraph_node_type_pg_mapping: Dict[str, Dict[str, Set[str]]] = collections.defaultdict(lambda: collections.defaultdict(set))\n\n        # Establish the set of node types that are relevant for process group creation.\n        if relevant_node_types_param is None:\n            self.relevant_node_types: Set[NodeType] = {\n                NodeType.MODEL_INFERENCE,\n                NodeType.MODEL_TRAIN,\n            }\n        else:\n            if not isinstance(relevant_node_types_param, set) or not all(isinstance(nt, NodeType) for nt in relevant_node_types_param):\n                raise ValueError(\"relevant_node_types_param must be a set of NodeType enums.\")\n            self.relevant_node_types: Set[NodeType] = relevant_node_types_param\n\n        self._compute_group_configurations()\n\n    def _clear_internal_mappings(self):\n        \"\"\"Resets all internal state dictionaries to an empty state.\"\"\"\n        self.node_ranks_mapping.clear()\n        self.process_group_spec.clear()\n        self.node_process_group_mapping.clear()\n        self.node_type_process_group_mapping.clear()\n        self.subgraph_node_type_pg_mapping.clear()\n\n    def _collect_initial_topology_info(\n        self,\n    ) -> Tuple[Dict[str, Set[int]], Dict[str, \"NodeType\"], Dict[str, Set[str]]]:\n        \"\"\"\n        Scans the rank-to-taskgraph mapping to build an initial understanding of the topology.\n\n        It focuses only on nodes whose types are in `self.relevant_node_types`.\n\n        Returns:\n            A tuple containing:\n            - graph_id_to_ranks: Mapping of a graph ID to the set of ranks running it.\n            - node_id_to_type: Mapping of a relevant node ID to its NodeType.\n            - graph_id_to_node_ids: Mapping of a graph ID to the set of relevant node IDs within it.\n        \"\"\"\n        graph_id_to_ranks = collections.defaultdict(set)\n        node_id_to_type: Dict[str, \"NodeType\"] = {}\n        graph_id_to_node_ids: Dict[str, Set[str]] = collections.defaultdict(set)\n        processed_graph_ids = set()\n\n        for rank, tg_instance in self.ranks_taskgraph_mapping.items():\n            if not tg_instance:\n                continue\n\n            gid = tg_instance.graph_id\n            has_relevant_node = False\n\n            # Process the structure of each unique graph only once.\n            if gid not in 
processed_graph_ids:\n                for node in tg_instance.nodes.values():\n                    if node.node_type in self.relevant_node_types:\n                        has_relevant_node = True\n                        graph_id_to_node_ids[gid].add(node.node_id)\n                        if node.node_id not in node_id_to_type:\n                            node_id_to_type[node.node_id] = node.node_type\n                if has_relevant_node:\n                    processed_graph_ids.add(gid)\n            # If graph was already processed, just check if it was deemed relevant.\n            elif gid in graph_id_to_node_ids:\n                has_relevant_node = True\n\n            # If the graph contains any relevant nodes, associate the current rank with it.\n            if has_relevant_node:\n                graph_id_to_ranks[gid].add(rank)\n\n        return graph_id_to_ranks, node_id_to_type, graph_id_to_node_ids\n\n    def _aggregate_ranks_for_nodes(\n        self,\n        graph_id_to_ranks: Dict[str, Set[int]],\n        graph_id_to_node_ids: Dict[str, Set[str]],\n    ) -> Dict[str, Set[int]]:\n        \"\"\"\n        Aggregates all ranks for each node based on the graph they belong to.\n\n        Returns:\n            A dictionary mapping each node ID to the complete set of ranks that execute it.\n        \"\"\"\n        node_id_to_final_ranks = collections.defaultdict(set)\n        for gid, nodes_in_graph in graph_id_to_node_ids.items():\n            ranks_for_gid = graph_id_to_ranks.get(gid, set())\n            if not ranks_for_gid:\n                continue\n            for node_id in nodes_in_graph:\n                node_id_to_final_ranks[node_id].update(ranks_for_gid)\n        return node_id_to_final_ranks\n\n    def _populate_node_rank_mappings(self, node_id_to_final_ranks: Dict[str, Set[int]]):\n        \"\"\"Populates `self.node_ranks_mapping` with sorted lists of ranks for determinism.\"\"\"\n        for nid, ranks_set in node_id_to_final_ranks.items():\n            self.node_ranks_mapping[nid] = sorted(list(ranks_set))\n\n    def _define_process_groups(self) -> Dict[Tuple[int, ...], str]:\n        \"\"\"\n        Defines process groups based on unique rank configurations found across all nodes.\n\n        Returns:\n            A mapping from a unique rank tuple to its generated process group name.\n        \"\"\"\n        # Find all unique combinations of ranks that need to communicate.\n        unique_rank_configs: Set[Tuple[int, ...]] = {tuple(ranks) for ranks in self.node_ranks_mapping.values()}\n\n        # Assign a unique, deterministic name to each unique rank configuration.\n        rank_config_to_group_name: Dict[Tuple[int, ...], str] = {}\n        for i, rank_tuple in enumerate(sorted(list(unique_rank_configs))):\n            group_name = f\"process_group_{i + 1}\"\n            self.process_group_spec[group_name] = list(rank_tuple)\n            rank_config_to_group_name[rank_tuple] = group_name\n        return rank_config_to_group_name\n\n    def _populate_final_node_and_type_assignments(\n        self,\n        rank_config_to_group_name: Dict[Tuple[int, ...], str],\n        node_id_to_type: Dict[str, \"NodeType\"],\n    ):\n        \"\"\"Populates the final mappings from node/type to process groups.\"\"\"\n        for nid, sorted_ranks_list in self.node_ranks_mapping.items():\n            rank_tuple = tuple(sorted_ranks_list)\n            group_name = rank_config_to_group_name.get(rank_tuple)\n            if not group_name:\n                continue\n\n            
self.node_process_group_mapping[nid] = group_name\n            current_node_type = node_id_to_type.get(nid)\n            if current_node_type:\n                self.node_type_process_group_mapping[current_node_type.value].add(group_name)\n\n    def _populate_subgraph_node_type_process_group_mapping(\n        self,\n        graph_id_to_node_ids: Dict[str, Set[str]],\n        node_id_to_type: Dict[str, \"NodeType\"],\n    ):\n        \"\"\"Populates the granular mapping of (subgraph, node_type) -> set of PGs.\"\"\"\n        for gid, nodes_in_graph in graph_id_to_node_ids.items():\n            for node_id in nodes_in_graph:\n                node_type = node_id_to_type.get(node_id)\n                pg_name = self.node_process_group_mapping.get(node_id)\n\n                if node_type and pg_name:\n                    self.subgraph_node_type_pg_mapping[gid][node_type.value].add(pg_name)\n\n    def _compute_group_configurations(self):\n        \"\"\"\n        Orchestrates the step-by-step process of computing all process group configurations.\n        \"\"\"\n        self._clear_internal_mappings()\n\n        if not self.ranks_taskgraph_mapping:\n            logger.warning(\"Ranks to TaskGraph mapping is empty. No process groups to compute.\")\n            return\n\n        # Step 1: Discover the basic topology of graphs, nodes, and ranks.\n        graph_id_to_ranks, node_id_to_type, graph_id_to_node_ids = self._collect_initial_topology_info()\n\n        if not node_id_to_type:\n            logger.warning(\"No nodes of a relevant type found. No process groups formed.\")\n            return\n\n        # Step 2: Determine the full set of ranks for each node.\n        node_id_to_final_ranks = self._aggregate_ranks_for_nodes(graph_id_to_ranks, graph_id_to_node_ids)\n        self._populate_node_rank_mappings(node_id_to_final_ranks)\n\n        if not self.node_ranks_mapping:\n            return\n\n        # Step 3: Define unique process groups from the rank configurations.\n        rank_config_to_pg_name = self._define_process_groups()\n\n        # Step 4: Create the final mappings for nodes and types.\n        self._populate_final_node_and_type_assignments(rank_config_to_pg_name, node_id_to_type)\n        self._populate_subgraph_node_type_process_group_mapping(graph_id_to_node_ids, node_id_to_type)\n\n    # --- Public API Methods ---\n\n    def get_group_spec(self, group_name: str) -> Optional[Dict[str, Any]]:\n        \"\"\"Retrieves the specification (list of ranks) for a single process group.\"\"\"\n        ranks_list = self.process_group_spec.get(group_name)\n        return {\"ranks\": ranks_list} if ranks_list is not None else None\n\n    def get_all_specs(self) -> Dict[str, Dict[str, Any]]:\n        \"\"\"Retrieves all defined process group specifications.\"\"\"\n        return {name: {\"ranks\": ranks} for name, ranks in self.process_group_spec.items()}\n\n    def get_node_assignment(self, node_id: str) -> Optional[Dict[str, Any]]:\n        \"\"\"Retrieves the rank and process group assignment for a specific node.\"\"\"\n        if node_id in self.node_process_group_mapping:\n            return {\n                \"ranks\": self.node_ranks_mapping[node_id],\n                \"process_group_name\": self.node_process_group_mapping[node_id],\n            }\n        return None\n\n    def get_process_groups_for_node_type(self, node_type_value: str) -> Set[str]:\n        \"\"\"Gets all process groups associated with a given node type globally.\"\"\"\n        # Return only for types that were 
configured as relevant during initialization.\n        if any(rt.value == node_type_value for rt in self.relevant_node_types):\n            return self.node_type_process_group_mapping.get(node_type_value, set())\n        return set()\n\n    def get_process_group_for_node_type_in_subgraph(self, graph_id: str, node_type_value: str) -> Set[str]:\n        \"\"\"Gets all process groups for a node type within a specific subgraph.\"\"\"\n        # Return only for types that were configured as relevant during initialization.\n        if any(rt.value == node_type_value for rt in self.relevant_node_types):\n            return self.subgraph_node_type_pg_mapping.get(graph_id, {}).get(node_type_value, set())\n        return set()\n\n\n# ==============================================================================\n# Logging Utility Functions\n# ==============================================================================\n\n\ndef _format_ranks_for_logging(ranks: Optional[List[int]], detailed_printing: bool, threshold: int = 10) -> str:\n    \"\"\"\n    Formats a list of ranks for concise and readable logging.\n\n    If detailed_printing is True or the list is short, it lists all ranks.\n    Otherwise, it shows a compact range and count.\n\n    Args:\n        ranks: The list of integer ranks.\n        detailed_printing: Flag to force printing all ranks.\n        threshold: The number of ranks above which compact view is used.\n\n    Returns:\n        A formatted string representing the list of ranks.\n    \"\"\"\n    if not ranks:\n        return \"N/A\"\n    if detailed_printing or len(ranks) <= threshold:\n        return str(sorted(ranks))  # Sort for consistent output\n    else:\n        return f\"[{min(ranks)}...{max(ranks)}] (Count: {len(ranks)})\"\n\n\ndef _log_group_specs_report(pgm: ProcessGroupManager, detailed_printing: bool, threshold: int) -> List[str]:\n    \"\"\"Generates the log report for all process group specifications.\"\"\"\n    report_lines = []\n    all_specs = pgm.get_all_specs()\n    if not all_specs:\n        report_lines.append(\"No process group specifications found.\")\n        return report_lines\n\n    report_lines.append(\"All Process Group Specifications:\")\n    for group_name, spec in sorted(all_specs.items()):\n        ranks_str = _format_ranks_for_logging(spec.get(\"ranks\"), detailed_printing, threshold)\n        report_lines.append(f\"  - Group '{group_name}': Ranks {ranks_str}\")\n    return report_lines\n\n\ndef _log_node_assignments_report(\n    pgm: ProcessGroupManager,\n    nodes_to_query: Optional[List[str]],\n    detailed_printing: bool,\n    threshold: int,\n) -> List[str]:\n    \"\"\"Generates the log report for node-to-process-group assignments.\"\"\"\n    report_lines = []\n    all_mapped_nodes = sorted(pgm.node_process_group_mapping.keys())\n\n    # Determine which nodes to generate the report for.\n    if nodes_to_query is None:\n        nodes_for_report = all_mapped_nodes\n    else:\n        nodes_for_report = nodes_to_query\n\n    if not all_mapped_nodes:\n        report_lines.append(\"No relevant node assignments found.\")\n        return report_lines\n\n    report_lines.append(\"Node Assignments:\")\n    if nodes_to_query is None:\n        report_lines.append(f\"  (Logging all {len(all_mapped_nodes)} relevant node assignments)\")\n\n    for node_id in nodes_for_report:\n        assignment = pgm.get_node_assignment(node_id)\n        if assignment:\n            ranks_str = _format_ranks_for_logging(assignment.get(\"ranks\"), detailed_printing, 
threshold)\n            report_lines.append(f\"  - Node '{node_id}': Assigned to PG '{assignment['process_group_name']}', Ranks {ranks_str}\")\n        # Only report missing if it was specifically requested.\n        elif node_id in (nodes_to_query or []):\n            report_lines.append(f\"  - Node '{node_id}': No assignment found (or not a relevant node).\")\n\n    return report_lines\n\n\ndef _log_global_type_mappings_report(pgm: ProcessGroupManager, node_types_to_query: Optional[List[NodeType]]) -> List[str]:\n    \"\"\"Generates the log report for global node type to process group mappings.\"\"\"\n    report_lines = []\n\n    # Determine which node types to query.\n    if node_types_to_query is None:\n        types_for_report = sorted(pgm.relevant_node_types, key=lambda nt: nt.value)\n    else:\n        types_for_report = node_types_to_query\n\n    if not pgm.relevant_node_types:\n        report_lines.append(\"No node types were configured as relevant in the PGM.\")\n        return report_lines\n\n    report_lines.append(\"Process Groups per Node Type (Global):\")\n    if node_types_to_query is None:\n        report_lines.append(\"  (Logging for all configured relevant node types)\")\n\n    for node_type in types_for_report:\n        pg_names = pgm.get_process_groups_for_node_type(node_type.value)\n        if pg_names:\n            report_lines.append(f\"  - NodeType '{node_type.value}': Associated with PGs {sorted(list(pg_names))}\")\n        # Only report if the type was relevant but had no groups.\n        elif node_type in pgm.relevant_node_types:\n            report_lines.append(f\"  - NodeType '{node_type.value}': No PGs found.\")\n\n    return report_lines\n\n\ndef _log_subgraph_mappings_report(\n    pgm: ProcessGroupManager,\n    subgraphs_to_query: Optional[List[str]],\n    node_types_to_query: Optional[List[NodeType]],\n) -> List[str]:\n    \"\"\"Generates the log report for subgraph-specific node type mappings.\"\"\"\n    report_lines = []\n\n    # Determine which subgraphs and node types to query.\n    all_mapped_subgraphs = sorted(pgm.subgraph_node_type_pg_mapping.keys())\n    subgraphs_for_report = subgraphs_to_query or all_mapped_subgraphs\n    types_for_report = node_types_to_query or sorted(pgm.relevant_node_types, key=lambda nt: nt.value)\n\n    if not all_mapped_subgraphs:\n        report_lines.append(\"No subgraph-specific mappings found.\")\n        return report_lines\n\n    report_lines.append(\"Process Groups per NodeType within Subgraphs:\")\n    if subgraphs_to_query is None:\n        report_lines.append(f\"  (Logging for all {len(all_mapped_subgraphs)} available subgraphs)\")\n\n    for subgraph_id in subgraphs_for_report:\n        if subgraph_id not in pgm.subgraph_node_type_pg_mapping:\n            if subgraphs_to_query is not None:\n                report_lines.append(f\"  Subgraph ID: '{subgraph_id}' - No mappings found.\")\n            continue\n\n        report_lines.append(f\"  Subgraph ID: '{subgraph_id}'\")\n        found_any_pg = False\n        for node_type in types_for_report:\n            pg_names = pgm.get_process_group_for_node_type_in_subgraph(subgraph_id, node_type.value)\n            if pg_names:\n                report_lines.append(f\"    - NodeType '{node_type.value}': Associated with PGs {sorted(list(pg_names))}\")\n                found_any_pg = True\n\n        if not found_any_pg:\n            report_lines.append(f\"    No process groups found for any of the queried node types in this subgraph.\")\n\n    return report_lines\n\n\ndef 
log_process_group_manager_details(\n    pgm: ProcessGroupManager, specific_nodes_to_query: Optional[List[str]] = None, specific_node_types_to_query: Optional[List[NodeType]] = None, specific_subgraphs_to_query: Optional[List[str]] = None, detailed_rank_printing: bool = False, rank_print_threshold: int = 16, log_level: str = \"info\"\n):\n    \"\"\"\n    Collects details from a ProcessGroupManager and logs them in a structured, aggregated message.\n\n    This function provides a comprehensive snapshot of the PGM's state, including\n    all defined process groups, node assignments, and type-based mappings,\n    with options for controlling log verbosity.\n\n    Args:\n        pgm: The initialized ProcessGroupManager instance.\n        specific_nodes_to_query: Optional list of node IDs to query for assignments.\n        specific_node_types_to_query: Optional list of NodeTypes to query for PG associations.\n        specific_subgraphs_to_query: Optional list of subgraph IDs to query.\n        detailed_rank_printing: If True, prints all ranks; otherwise, uses a compact range for large lists.\n        rank_print_threshold: The list size above which compact rank printing is used.\n    \"\"\"\n    # NOTE: This function has been refactored into smaller helpers for clarity.\n    # The final log output remains identical to the original version.\n\n    header = [\n        \"\\n\\n--- Process Group Manager Details (Aggregated Log) ---\",\n        f\"Detailed Rank Printing: {detailed_rank_printing}, Threshold: {rank_print_threshold}\",\n        \"-\" * 50,\n    ]\n\n    # Generate each section of the report using helper functions.\n    specs_report = _log_group_specs_report(pgm, detailed_rank_printing, rank_print_threshold)\n    nodes_report = _log_node_assignments_report(pgm, specific_nodes_to_query, detailed_rank_printing, rank_print_threshold)\n    types_report = _log_global_type_mappings_report(pgm, specific_node_types_to_query)\n    subgraphs_report = _log_subgraph_mappings_report(pgm, specific_subgraphs_to_query, specific_node_types_to_query)\n\n    # Assemble the final log message.\n    full_report = header + specs_report + [\"-\" * 50] + nodes_report + [\"-\" * 50] + types_report + [\"-\" * 50] + subgraphs_report + [\"--- End of Process Group Manager Details (Aggregated Log) ---\\n\\n\"]\n\n    # Log the full report at the specified log level.\n    valid_levels = [\"INFO\", \"DEBUG\", \"WARNING\", \"ERROR\", \"CRITICAL\"]\n    if log_level.upper() not in valid_levels:\n        raise ValueError(f\"Invalid log level: {log_level}. Choose from {valid_levels}.\")\n    logger.log(log_level.upper(), \"\\n\".join(full_report))\n"
  },
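  {
    "path": "examples/sketches/process_group_manager_example.py",
    "content": "\"\"\"Illustrative sketch (hypothetical example file, not part of the original tree).\n\nShows how ProcessGroupManager derives process groups from the rank-to-task-graph\nmapping and how log_process_group_manager_details reports them (assuming both are\nexported from the same module). The SimpleNamespace objects below are stand-ins\nthat expose only the attributes ProcessGroupManager reads from a TaskGraph\n(graph_id, nodes) and its nodes (node_id, node_type); real callers pass TaskGraph\ninstances from siirl.execution.dag. Constructor keywords are assumed from the\n__init__ body.\n\"\"\"\n\nfrom types import SimpleNamespace\n\nfrom siirl.execution.dag.node import NodeType\nfrom siirl.execution.scheduler.process_group_manager import ProcessGroupManager, log_process_group_manager_details\n\n\ndef _fake_graph(graph_id, node_types):\n    # node_types: mapping of node_id -> NodeType for the relevant nodes only.\n    nodes = {nid: SimpleNamespace(node_id=nid, node_type=ntype) for nid, ntype in node_types.items()}\n    return SimpleNamespace(graph_id=graph_id, nodes=nodes)\n\n\nif __name__ == \"__main__\":\n    # Ranks 0-1 run graph g0 (actor training + rollout); ranks 2-3 run graph g1 (critic training).\n    g0 = _fake_graph(\"g0\", {\"actor_train\": NodeType.MODEL_TRAIN, \"rollout\": NodeType.MODEL_INFERENCE})\n    g1 = _fake_graph(\"g1\", {\"critic_train\": NodeType.MODEL_TRAIN})\n    pgm = ProcessGroupManager(total_num_workers=4, ranks_taskgraph_mapping={0: g0, 1: g0, 2: g1, 3: g1})\n\n    # Two unique rank configurations -> two process groups, e.g.\n    # {\"process_group_1\": {\"ranks\": [0, 1]}, \"process_group_2\": {\"ranks\": [2, 3]}}.\n    print(pgm.get_all_specs())\n    print(pgm.get_node_assignment(\"actor_train\"))\n    print(pgm.get_process_groups_for_node_type(NodeType.MODEL_TRAIN.value))\n\n    log_process_group_manager_details(pgm, detailed_rank_printing=True)\n"
  },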
  {
    "path": "siirl/execution/scheduler/ray_actor_manager.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nManages a group of Ray actors for distributed workloads, handling their\ncreation, lifecycle, and communication.\n\"\"\"\n\nimport os\nimport re\nimport time\nfrom enum import Enum\nfrom typing import Any, Dict, List, Optional, Tuple\nimport asyncio\n\nimport ray\nfrom loguru import logger\nfrom ray.actor import ActorHandle\nfrom ray.experimental.state.api import get_actor\nfrom ray.util import list_named_actors\n\nfrom siirl.utils.extras.device import get_device_name\nfrom siirl.execution.scheduler.process_group_manager import ProcessGroupManager\nfrom siirl.params import SiiRLArguments\nfrom siirl.engine.base_worker import RayClassWithInitArgs, RayResourcePool, WorkerGroup, get_random_string, sort_placement_group_by_node_ip\nfrom siirl.execution.dag import TaskGraph\nfrom siirl.dag_worker.dagworker import DAGWorker\nfrom siirl.execution.rollout_flow.multiturn.agent_loop import AgentLoopManager\n\n\nclass DistributedEnv(Enum):\n    \"\"\"Enumeration for distributed environment variable keys.\"\"\"\n\n    MASTER_ADDR = \"MASTER_ADDR\"\n    MASTER_PORT = \"MASTER_PORT\"\n    WORLD_SIZE = \"WORLD_SIZE\"\n    RANK = \"RANK\"\n    WG_PREFIX = \"WG_PREFIX\"\n    WG_BACKEND = \"WG_BACKEND\"\n    RAY_LOCAL_WORLD_SIZE = \"RAY_LOCAL_WORLD_SIZE\"\n    RAY_LOCAL_RANK = \"RAY_LOCAL_RANK\"\n    DGA_PROCESS_GROUP = \"DGA_PROCESS_GROUP\"\n\n\n# --- Constants ---\nACTOR_STATE_ALIVE = \"ALIVE\"\nREGISTER_CENTER_POLL_INTERVAL_S = 1\nREGISTER_CENTER_LOG_INTERVAL_S = 30\nRAY_BACKEND = \"ray\"\n\n\nclass RayActorManager(WorkerGroup):\n    \"\"\"\n    Manages the lifecycle of a group of distributed Ray actors (workers).\n\n    This class handles the creation of actors based on resource availability,\n    assigns ranks, and sets up the necessary environment variables for\n    distributed communication. 
It provides a unified interface to execute\n    methods synchronously or asynchronously across all or specific workers.\n\n    Attributes:\n        worker_names (List[str]): A list of the generated names for each actor.\n        master_address (str): The network address of the rank 0 worker.\n        master_port (str): The network port of the rank 0 worker.\n        workers (List[ActorHandle]): The list of Ray actor handles managed by this group.\n        world_size (int): The total number of workers in the group.\n    \"\"\"\n\n    def __init__(\n        self,\n        resource_pool: RayResourcePool,\n        base_config: SiiRLArguments,\n        process_manager: ProcessGroupManager,\n        rank_taskgraph_mapping: Dict[int, \"TaskGraph\"],\n        data_coordinator_handle: ActorHandle,\n        metric_worker_handle: ActorHandle,\n        bin_pack: bool = True,\n        name_prefix: Optional[str] = None,\n        ray_wait_register_center_timeout: int = 300,\n        device_name=\"cuda\",\n        **kwargs,\n    ) -> None:\n        \"\"\"\n        Initializes the RayActorManager.\n\n        Args:\n            resource_pool: The pool of resources available for placing actors.\n            base_config: Base configuration arguments for the workers.\n            process_manager: Manager for the distributed process group.\n            rank_taskgraph_mapping: Mapping of worker ranks to their task graphs.\n            data_coordinator_handle: Handle to the central DataCoordinator actor.\n            metric_worker_handle: The Ray actor handle for the Central Metric Worker.\n            bin_pack: If True, use strict packing strategy for placement groups.\n            name_prefix: A custom prefix for actor names. A random one is\n                         generated if None.\n            ray_wait_register_center_timeout: Seconds to wait for the rank 0\n                                              registration actor to appear.\n            **kwargs: Additional arguments for the base WorkerGroup.\n        \"\"\"\n        super().__init__(resource_pool=resource_pool, **kwargs)\n\n        self.name_prefix: str = get_random_string(length=6) if name_prefix is None else name_prefix\n        self._ray_wait_register_center_timeout = ray_wait_register_center_timeout\n\n        self._worker_names: List[str] = []\n        self._world_size: int = resource_pool.world_size\n        self._master_addr: Optional[str] = None\n        self._master_port: Optional[str] = None\n\n        self.base_config = base_config\n        self.process_manager = process_manager\n        self.rank_taskgraph_mapping = rank_taskgraph_mapping\n        self.data_coordinator_handle = data_coordinator_handle\n        self.metric_worker_handle = metric_worker_handle\n        self.device_name = device_name\n\n        # Prepare the Ray actor class with its initial arguments.\n        self.ray_actor_class = RayClassWithInitArgs(\n            ray.remote(DAGWorker),\n            config=self.base_config,\n            process_group_manager=self.process_manager,\n            taskgraph_mapping=self.rank_taskgraph_mapping,\n            data_coordinator=self.data_coordinator_handle,\n            metric_worker=self.metric_worker_handle,\n            device_name=self.device_name,\n        )\n\n        self._initialize_workers(resource_pool=resource_pool, bin_pack=bin_pack)\n\n    def _initialize_workers(self, resource_pool: RayResourcePool, bin_pack: bool) -> None:\n        \"\"\"\n        Creates and configures all worker actors based on the resource 
pool.\n\n        This method orchestrates the creation of placement groups, iterates through\n        them to launch actors with the correct rank and environment variables,\n        and establishes the master address for distributed coordination.\n        \"\"\"\n        strategy = \"STRICT_PACK\" if bin_pack else \"PACK\"\n        placement_groups = resource_pool.get_placement_groups(strategy=strategy, device_name=self.device_name)\n        sorted_pgs = sort_placement_group_by_node_ip(placement_groups)\n\n        num_gpus_per_worker = 1 / resource_pool.max_colocate_count\n        local_world_size = resource_pool.store[0]\n        rank = -1\n\n        for pg_index, placement_group in enumerate(sorted_pgs):\n            if local_world_size > placement_group.bundle_count:\n                raise ValueError(f\"Placement group for '{self.name_prefix}' has too few bundles ({placement_group.bundle_count}) to support the required local world size ({local_world_size}).\")\n\n            for local_rank in range(local_world_size):\n                rank += 1\n                worker = self._create_worker_actor(\n                    rank=rank,\n                    local_rank=local_rank,\n                    local_world_size=local_world_size,\n                    pg_index=pg_index,\n                    placement_group=placement_group,\n                    num_gpus_per_worker=num_gpus_per_worker,\n                    use_gpu=resource_pool.use_gpu,\n                    device_name=self.device_name,\n                )\n                self._workers.append(worker)\n\n                if rank == 0:\n                    # Rank 0 worker is special: it establishes the master\n                    # address and port for the entire worker group.\n                    self._master_addr, self._master_port = self._get_register_center_and_master_info()\n        work_futures = self.map_async(method_name=\"init_graph\")\n        ray.get(work_futures)\n        # only support single agent\n        \n    def _get_register_center_and_master_info(self) -> Tuple[str, str]:\n        \"\"\"\n        Waits for the registration actor to be available and fetches the\n        master address and port from it.\n\n        Returns:\n            A tuple containing the master address and master port.\n\n        Raises:\n            TimeoutError: If the registration actor cannot be found within the\n                          configured timeout.\n        \"\"\"\n        register_center_name = f\"{self.name_prefix}_register_center\"\n        logger.info(f\"Waiting for registration center actor: '{register_center_name}'...\")\n        start_time = time.time()\n\n        while time.time() - start_time < self._ray_wait_register_center_timeout:\n            try:\n                # Use list_named_actors for a more robust check.\n                if register_center_name in list_named_actors():\n                    register_center_actor = ray.get_actor(register_center_name)\n                    logger.success(f\"Successfully connected to '{register_center_name}'.\")\n                    rank_zero_info = ray.get(register_center_actor.get_rank_zero_info.remote())\n\n                    master_addr = rank_zero_info[DistributedEnv.MASTER_ADDR.value]\n                    master_port = rank_zero_info[DistributedEnv.MASTER_PORT.value]\n                    return master_addr, master_port\n            except Exception as e:\n                logger.warning(f\"Attempt to get register center failed, will retry. 
Error: {e}\")\n\n            elapsed = int(time.time() - start_time)\n            if elapsed > 0 and elapsed % REGISTER_CENTER_LOG_INTERVAL_S == 0:\n                logger.warning(f\"Still waiting for '{register_center_name}'. Elapsed: {elapsed}s / {self._ray_wait_register_center_timeout}s.\")\n            time.sleep(REGISTER_CENTER_POLL_INTERVAL_S)\n\n        raise TimeoutError(\n            f\"Failed to get register_center_actor '{register_center_name}' \"\n            f\"within {self._ray_wait_register_center_timeout} seconds. \"\n            f\"Existing actors: {list_named_actors(all_namespaces=True)}. \"\n            \"Ensure Ray resources from previous runs are cleaned up or \"\n            \"increase 'trainer.ray_wait_register_center_timeout'.\"\n        )\n\n    def _create_worker_actor(\n        self,\n        rank: int,\n        local_rank: int,\n        local_world_size: int,\n        pg_index: int,\n        placement_group: \"ray.util.placement_group.PlacementGroup\",\n        num_gpus_per_worker: float,\n        use_gpu: bool,\n        device_name: str,\n    ) -> ActorHandle:\n        \"\"\"\n        Configures and creates a single worker actor.\n\n        Args:\n            rank: The global rank of the worker.\n            local_rank: The rank of the worker on its local node.\n            local_world_size: The number of workers on the local node.\n            pg_index: The index of the placement group being used.\n            placement_group: The Ray placement group for this worker.\n            num_gpus_per_worker: The number of GPUs to assign to the worker.\n            use_gpu: A boolean indicating if the worker should use a GPU.\n            device_name: The name of the device to use (\"cuda\" or \"npu\").\n\n        Returns:\n            The handle to the newly created Ray actor.\n        \"\"\"\n        # --- 1. Prepare Environment Variables ---\n        env_vars = {\n            DistributedEnv.WORLD_SIZE.value: str(self.world_size),\n            DistributedEnv.RANK.value: str(rank),\n            DistributedEnv.WG_PREFIX.value: self.name_prefix,\n            DistributedEnv.WG_BACKEND.value: RAY_BACKEND,\n            DistributedEnv.RAY_LOCAL_WORLD_SIZE.value: str(local_world_size),\n            DistributedEnv.RAY_LOCAL_RANK.value: str(local_rank),\n            DistributedEnv.DGA_PROCESS_GROUP.value: os.environ.get(\"DGA_PROCESS_GROUP\", \"\"),\n        }\n        if rank != 0:\n            if not self._master_addr or not self._master_port:\n                raise ConnectionError(\"Master address and port not set before creating non-zero rank workers.\")\n            env_vars[DistributedEnv.MASTER_ADDR.value] = self._master_addr\n            env_vars[DistributedEnv.MASTER_PORT.value] = self._master_port\n\n        # --- 2. Generate a unique and descriptive actor name ---\n        base_class_repr = type(self.ray_actor_class.cls).__name__  # e.g., \"ActorClass(DAGWorker)\"\n        match = re.search(r\"ActorClass\\(([^)]+)\\)\", base_class_repr)\n        actor_class_name = match.group(1) if match else base_class_repr\n        actor_name = f\"{self.name_prefix}_{actor_class_name}_{pg_index}:{local_rank}\"\n        self._worker_names.append(actor_name)\n\n        # --- 3. Set actor-specific options ---\n        self.ray_actor_class.update_options(\n            {\n                \"runtime_env\": {\"env_vars\": env_vars},\n                \"name\": actor_name,\n            }\n        )\n\n        # --- 4. 
Create the actor ---\n        logger.debug(f\"Creating actor '{actor_name}' with rank {rank}.\")\n        worker = self.ray_actor_class(placement_group=placement_group, placement_group_bundle_idx=local_rank, use_gpu=use_gpu, num_gpus=num_gpus_per_worker, device_name=device_name)\n        return worker\n\n    def _is_worker_alive(self, worker: ActorHandle) -> bool:\n        \"\"\"\n        Checks if a given worker actor is in the 'ALIVE' state.\n\n        Note: This uses a Ray experimental API, which may change in the future.\n\n        Args:\n            worker: The actor handle to check.\n\n        Returns:\n            True if the worker is alive, False otherwise.\n        \"\"\"\n        try:\n            worker_state_dict = get_actor(worker._actor_id.hex())\n            return worker_state_dict.get(\"state\", \"undefined\") == ACTOR_STATE_ALIVE if worker_state_dict else False\n        except Exception:\n            return False\n\n    def _invoke_on_worker(self, worker: ActorHandle, method_name: str, *args: Any, **kwargs: Any) -> Any:\n        \"\"\"Invokes a method on a single remote worker actor.\"\"\"\n        remote_method = getattr(worker, method_name)\n        return remote_method.remote(*args, **kwargs)\n\n    def map_sync(self, method_name: str, *args: Any, **kwargs: Any) -> List[Any]:\n        \"\"\"\n        Executes a method on all workers and waits for all results to complete.\n\n        This is the synchronous (blocking) version of `map_async`.\n\n        Args:\n            method_name: The name of the method to execute.\n            *args: Positional arguments for the method.\n            **kwargs: Keyword arguments for the method.\n\n        Returns:\n            A list containing the results from each worker.\n        \"\"\"\n        return ray.get(self.map_async(method_name, *args, **kwargs))\n\n    def map_async(self, method_name: str, *args: Any, **kwargs: Any) -> List[ray.ObjectRef]:\n        \"\"\"\n        Executes a method on all workers asynchronously.\n\n        Special Behavior:\n        If all positional and keyword arguments are lists of the same length\n        as the number of workers, the arguments are \"unzipped\" and distributed.\n        The i-th worker receives the i-th element from each argument list.\n        Otherwise, all workers receive the same arguments.\n\n        Args:\n            method_name: The name of the method to execute.\n            *args: Positional arguments for the method.\n            **kwargs: Keyword arguments for the method.\n\n        Returns:\n            A list of ray.ObjectRef handles to the results.\n        \"\"\"\n        num_workers = len(self._workers)\n\n        # Check for the special argument-splitting case\n        all_args_are_distributable = all(isinstance(arg, list) and len(arg) == num_workers for arg in args) and all(isinstance(val, list) and len(val) == num_workers for val in kwargs.values())\n\n        if all_args_are_distributable and (args or kwargs):\n            futures = []\n            for i in range(num_workers):\n                # Slice arguments for the i-th worker\n                sliced_args = tuple(arg[i] for arg in args)\n                sliced_kwargs = {k: v[i] for k, v in kwargs.items()}\n                futures.append(self._invoke_on_worker(self._workers[i], method_name, *sliced_args, **sliced_kwargs))\n            return futures\n\n        # Default case: all workers get the same arguments\n        return [self._invoke_on_worker(w, method_name, *args, **kwargs) for w in self._workers]\n\n    
def map(self, method_name: str, *args: Any, **kwargs: Any) -> List[ray.ObjectRef]:\n        \"\"\"\n        Alias for `map_async`. Executes a method on all workers asynchronously.\n        \"\"\"\n        return self.map_async(method_name, *args, **kwargs)\n\n    @property\n    def worker_names(self) -> List[str]:\n        \"\"\"Returns the names of all managed workers.\"\"\"\n        return self._worker_names\n\n    @property\n    def master_address(self) -> Optional[str]:\n        \"\"\"Returns the address of the master (rank 0) worker.\"\"\"\n        return self._master_addr\n\n    @property\n    def master_port(self) -> Optional[str]:\n        \"\"\"Returns the port of the master (rank 0) worker.\"\"\"\n        return self._master_port\n\n    @property\n    def workers(self) -> List[ActorHandle]:\n        \"\"\"Returns the list of Ray actor handles.\"\"\"\n        return self._workers\n\n    @property\n    def world_size(self) -> int:\n        \"\"\"Returns the total number of workers.\"\"\"\n        return self._world_size\n"
  },
  {
    "path": "siirl/execution/scheduler/resource_manager.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\n# ==============================================================================\n# This file has been modified from the original version in the VERL library.\n# The original source code can be found at:\n# https://github.com/volcengine/verl\n#\n# Modifications Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n# ==============================================================================\n\nimport time\nimport ray\nfrom siirl.engine.base_worker import RayResourcePool\nfrom dataclasses import dataclass, field\nfrom loguru import logger\n\n\n@dataclass\nclass ResourcePoolManager:\n    \"\"\"\n    Define a resource pool specification. Resource pool will be initialized first.\n    \"\"\"\n\n    resource_pool_spec: dict[str, list[int]]\n    resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict)\n\n    def create_resource_pool(self):\n        for resource_pool_name, process_on_nodes in self.resource_pool_spec.items():\n            # max_colocate_count means the number of WorkerGroups (i.e. processes) in each RayResourcePool\n            # For FSDP backend, we recommend using max_colocate_count=1 that merge all WorkerGroups into one.\n            # For Megatron backend, we recommend using max_colocate_count>1\n            # that can utilize different WorkerGroup for differnt models\n            resource_pool = RayResourcePool(process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1, name_prefix=resource_pool_name)\n            self.resource_pool_dict[resource_pool_name] = resource_pool\n\n        self._check_resource_available()\n\n    def get_resource_pool(self, resource_pool_name: str) -> RayResourcePool:\n        return self.resource_pool_dict.get(resource_pool_name, None)\n\n    def get_n_gpus(self) -> int:\n        \"\"\"Get the number of gpus in this cluster.\"\"\"\n        return sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes])\n\n    def _check_resource_available(self, timeout=60, interval=1):\n        \"\"\"\n        Checks if the required resources are available in the Ray cluster,\n        waiting up to a specified timeout for nodes to become ready.\n\n        Args:\n            timeout (int): Maximum time to wait in seconds.\n            interval (int): Time to sleep between checks in seconds.\n        \"\"\"\n        logger.info(f\"Checking for available resources. 
Will wait for up to {timeout} seconds.\")\n        start_time = time.time()\n\n        # First, calculate the total required GPUs, which is a fixed value based on the spec.\n        total_required_gpus = sum(n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes)\n\n        while time.time() - start_time < timeout:\n            node_available_resources = ray.state.available_resources_per_node()\n            node_available_gpus = {\n                node: node_info.get(\"GPU\", 0) if \"GPU\" in node_info else node_info.get(\"NPU\", 0)\n                for node, node_info in node_available_resources.items()\n            }\n            # logger.success(node_available_resources)\n            total_available_gpus = sum(node_available_gpus.values())\n\n            if total_available_gpus >= total_required_gpus:\n                logger.success(f\"Total required GPUs ({total_required_gpus}) are available. Verifying placement possibility.\")\n                try:\n                    # Use a copy for the check to avoid modifying the original data during verification.\n                    self._verify_placement_possible(node_available_gpus.copy())\n                    logger.success(\"All resource pools can be satisfied. Proceeding.\")\n                    return  # Success, so exit the function immediately.\n                except ValueError as e:\n                    # Even if the total GPU count is sufficient, placement might fail due to resource distribution.\n                    # This situation usually does not change over time, so we fail fast.\n                    logger.error(f\"Placement check failed: {e}\")\n                    raise  # Raise the original placement error.\n\n            # If resources are not met, print a waiting message and sleep.\n            logger.info(f\"Waiting for nodes... Available GPUs: {total_available_gpus}/{total_required_gpus}. Retrying in {interval} seconds...\")\n            time.sleep(interval)\n\n        # If the loop finishes without meeting the condition, a timeout occurred.\n        final_available_gpus = sum(node_info.get(\"GPU\", 0) if \"GPU\" in node_info else node_info.get(\"NPU\", 0) for node_info in ray.state.available_resources_per_node().values())\n        error_msg = f\"Timed out after {timeout} seconds. The cluster does not have enough resources. 
Required: {total_required_gpus} GPUs, Available: {final_available_gpus} GPUs.\"\n        logger.error(error_msg)\n        raise TimeoutError(error_msg)\n\n    def _verify_placement_possible(self, available_gpus_per_node: dict):\n        \"\"\"\n        Checks if each resource pool can be satisfied with the current cluster topology.\n        This is a greedy check.\n\n        Args:\n            available_gpus_per_node (dict): A copy of the dictionary mapping node ID to its available GPU count.\n        \"\"\"\n        sorted_pools = sorted(self.resource_pool_spec.items(), key=lambda item: item[1][0], reverse=True)\n\n        for resource_pool_name, process_on_nodes in sorted_pools:\n            num_gpus_per_process, num_nodes = process_on_nodes[0], len(process_on_nodes)\n            found_nodes = 0\n            sorted_available_nodes = sorted(available_gpus_per_node.items(), key=lambda item: item[1], reverse=True)\n            temp_gpus_per_node = dict(sorted_available_nodes)\n\n            for node, available_gpus in temp_gpus_per_node.items():\n                if available_gpus >= num_gpus_per_process:\n                    temp_gpus_per_node[node] -= num_gpus_per_process\n                    found_nodes += 1\n                    if found_nodes == num_nodes:\n                        break\n\n            if found_nodes < num_nodes:\n                raise ValueError(f\"Resource pool '{resource_pool_name}' (requires {num_nodes} nodes with {num_gpus_per_process} GPUs each) cannot be satisfied by the current cluster resource distribution.\")\n\n            # If verification for this pool succeeds, update the main GPU availability dict for the next pool's check.\n            available_gpus_per_node.update(temp_gpus_per_node)\n"
  },
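  {
    "path": "examples/sketches/resource_pool_manager_example.py",
    "content": "\"\"\"Illustrative sketch (hypothetical example file, not part of the original tree).\n\nShows the resource_pool_spec convention consumed by ResourcePoolManager: each pool\nmaps to a list with one entry per physical node, and each entry is the number of\nGPUs requested on that node. Creating the Ray resource pools themselves requires a\nrunning Ray cluster, so that step is only indicated in a comment.\n\"\"\"\n\nfrom siirl.execution.scheduler.resource_manager import ResourcePoolManager\n\nif __name__ == \"__main__\":\n    # One pool spanning two nodes with 8 GPUs each.\n    manager = ResourcePoolManager(resource_pool_spec={\"global_pool\": [8, 8]})\n\n    # Sums every per-node entry across all pools: 16 here.\n    print(manager.get_n_gpus())\n\n    # With an initialized Ray cluster (ray.init(...)), the pools would be built via:\n    #   manager.create_resource_pool()\n    #   pool = manager.get_resource_pool(\"global_pool\")\n"
  },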
  {
    "path": "siirl/execution/scheduler/reward.py",
    "content": "# Copyright 2025 Individual Contributor: Thibaut Barroyer\n# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nThis module provides functionalities for dynamically loading and computing rewards\n\"\"\"\n\nimport importlib.util\nimport multiprocessing\nimport os\nimport sys\nfrom functools import partial\nfrom typing import Any, Callable, Dict, Optional, Tuple, Union\nfrom tensordict import TensorDict\nimport ray\nfrom loguru import logger\n\nfrom siirl.params import SiiRLArguments\nfrom siirl.utils.reward_score import default_compute_score\nfrom siirl.engine.reward_manager import (\n    DAPORewardManager,\n    NaiveRewardManager,\n    ParallelRewardManager,\n    EmbodiedRewardManager\n)\n\nTokenizer = Any\nRewardTensor = Any\nAnyRewardManager = Union[NaiveRewardManager, DAPORewardManager, ParallelRewardManager, EmbodiedRewardManager]\n\n\ndef load_custom_reward_function(config: SiiRLArguments) -> Optional[Callable]:\n    \"\"\"\n    Dynamically loads a custom reward function from a user-specified file.\n\n    This function reads the path and function name from the configuration,\n    imports the module, and returns the specified function.\n\n    Args:\n        config: The main SiiRLArguments configuration object which contains\n                the `custom_reward_function` settings.\n\n    Returns:\n        The loaded custom reward function wrapped with its configured keyword\n        arguments, or None if no custom function is specified.\n\n    Raises:\n        FileNotFoundError: If the specified Python file does not exist.\n        AttributeError: If the function is not found within the specified file.\n        RuntimeError: If the module cannot be loaded for other reasons.\n    \"\"\"\n    reward_fn_config = config.custom_reward_function\n    file_path = reward_fn_config.path\n\n    if not file_path:\n        return None\n\n    if not os.path.exists(file_path):\n        raise FileNotFoundError(f\"Custom reward function file not found: '{file_path}'\")\n\n    # Dynamically import the module from the given file path.\n    module_name = \"custom_module\"  # A placeholder name for the module.\n    spec = importlib.util.spec_from_file_location(module_name, file_path)\n    if spec is None or spec.loader is None:\n        raise RuntimeError(f\"Could not create module spec from '{file_path}'\")\n\n    module = importlib.util.module_from_spec(spec)\n    # This allows the module to be discoverable by other parts of the system\n    # if necessary, for instance during deserialization (unpickling).\n    sys.modules[module_name] = module\n\n    try:\n        spec.loader.exec_module(module)\n    except Exception as e:\n        raise RuntimeError(f\"Failed to execute module from '{file_path}': {e}\") from e\n\n    function_name = reward_fn_config.name\n    if not hasattr(module, function_name):\n        raise AttributeError(f\"Function '{function_name}' not found in custom reward file '{file_path}'.\")\n\n   
 logger.info(f\"Using custom reward function '{function_name}' from '{file_path}'\")\n    raw_fn = getattr(module, function_name)\n    reward_kwargs = dict(reward_fn_config.reward_kwargs)\n\n    # Wrap the function to pre-fill the custom keyword arguments.\n    return partial(raw_fn, **reward_kwargs)\n\n\ndef create_reward_manager(\n    config: SiiRLArguments,\n    tokenizer: Tokenizer,\n    num_examine: int,\n    **reward_kwargs,\n) -> AnyRewardManager:\n    \"\"\"\n    Factory function to instantiate and return the appropriate reward manager.\n\n    It selects the reward manager class based on the configuration and wires it\n    up with the correct scoring function, which can be a default, a sandbox-\n    based, or a custom function.\n\n    Args:\n        config: The SiiRLArguments configuration object.\n        tokenizer: The tokenizer instance to be used by the reward manager.\n        num_examine: The number of candidates to examine.\n        **reward_kwargs: Additional keyword arguments for the reward manager.\n\n    Returns:\n        An instantiated reward manager object.\n\n    Raises:\n        NotImplementedError: If the specified `reward_manager_name` is unknown.\n    \"\"\"\n    \n    # Map manager names to their respective classes for clean, extensible selection.\n    manager_map = {\n        \"naive\": NaiveRewardManager,\n        \"dapo\": DAPORewardManager,\n        \"parallel\": ParallelRewardManager,\n        \"embodied\": EmbodiedRewardManager\n    }\n\n    # Map each manager to its default compute_score function\n    # Note: compute_embodied_reward is imported lazily to avoid loading sklearn for LLM/VLM tasks\n    default_compute_score_map = {\n        \"naive\": default_compute_score,\n        \"prime\": default_compute_score,\n        \"batch\": default_compute_score,\n        \"dapo\": default_compute_score,\n        \"parallel\": default_compute_score,\n        \"embodied\": None,  # Will be loaded lazily if needed\n    }\n    reward_manager_name = config.reward_model.reward_manager\n    reward_manager_cls = manager_map.get(reward_manager_name)\n\n    if reward_manager_cls is None:\n        raise NotImplementedError(f\"Reward manager '{reward_manager_name}' is not implemented.\")\n\n    # Determine the final scoring function.\n    # Priority: Custom function > Sandbox function > Default function\n    compute_score_fn = load_custom_reward_function(config)\n\n    if compute_score_fn is None:\n        sandbox_config = config.reward_model.sandbox_fusion\n        sandbox_url = sandbox_config.get(\"url\") if sandbox_config else None\n\n        if sandbox_url:\n            logger.info(f\"Using sandbox-based reward scoring at URL: {sandbox_url}\")\n            # This semaphore should be managed carefully. 
Creating it here assumes\n            # this function is called once per worker/process.\n            manager = multiprocessing.Manager()\n            semaphore = manager.Semaphore(sandbox_config.get(\"max_concurrent\", 64))\n            compute_score_fn = partial(\n                default_compute_score,\n                sandbox_fusion_url=sandbox_url,\n                concurrent_semaphore=semaphore,\n            )\n        else:\n            # Fallback to the default scoring function.\n            compute_score_fn = default_compute_score_map.get(\n                reward_manager_name,\n                default_compute_score  # Fallback for any unmapped managers\n            )\n            \n            # Lazy import for embodied reward to avoid loading sklearn for LLM/VLM tasks\n            if compute_score_fn is None and reward_manager_name == \"embodied\":\n                from siirl.utils.reward_score.embodied import compute_embodied_reward\n                compute_score_fn = compute_embodied_reward\n                logger.info(\"Loaded embodied reward function (with sklearn dependencies)\")\n\n    return reward_manager_cls(\n        tokenizer=tokenizer,\n        num_examine=num_examine,\n        compute_score=compute_score_fn,\n        reward_fn_key=config.data.reward_fn_key,\n        **reward_kwargs,\n    )\n\n\ndef compute_reward(data: TensorDict, reward_fn: Callable) -> Tuple[RewardTensor, Dict[str, Any]]:\n    \"\"\"\n    Computes the reward for a given batch of data using the provided function.\n\n    This function includes robust error handling. If the reward function fails,\n    it logs a warning and returns a placeholder reward (e.g., zero) instead of\n    crashing.\n\n    Args:\n        data: A TensorDict object containing the batch of input data.\n        reward_fn: The reward function or manager method to call.\n\n    Returns:\n        A tuple containing:\n        - The reward tensor for the batch.\n        - A dictionary with any extra metadata returned by the reward function.\n    \"\"\"\n    try:\n        # Assumes reward_fn can return a dictionary with specific keys.\n        reward_result = reward_fn(data, return_dict=True)\n        reward_tensor = reward_result[\"reward_tensor\"]\n        extra_info = reward_result.get(\"reward_extra_info\", {})\n    except Exception:\n        # If the structured return fails, try a simpler call.\n        try:\n            reward_tensor = reward_fn(data)\n            extra_info = {}\n        except Exception as e:\n            logger.warning(f\"Error computing reward: {e}. Returning a zero tensor.\")\n            # Create a zero tensor of the expected shape as a fallback.\n            # This requires knowing the expected tensor type and shape.\n            # Assuming a shape of (batch_size,) and using a generic placeholder.\n            # This part may need adjustment based on the actual tensor library (torch/tf).\n            batch_size = len(data.prompts)  # Example of getting batch size\n            reward_tensor = [0.0] * batch_size  # Placeholder for actual tensor\n            extra_info = {}\n\n    return reward_tensor, extra_info\n"
  },
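  {
    "path": "examples/sketches/reward_contract_example.py",
    "content": "\"\"\"Illustrative sketch (hypothetical example file, not part of the original tree).\n\nShows the return contract compute_reward() expects from the callable it is given:\nwhen invoked with return_dict=True it should return a dict holding \"reward_tensor\"\nand, optionally, \"reward_extra_info\". The dummy function below is a stand-in for a\nconfigured reward manager, and its tensor shape is chosen purely for illustration.\n\"\"\"\n\nimport torch\nfrom tensordict import TensorDict\n\nfrom siirl.execution.scheduler.reward import compute_reward\n\n\ndef dummy_reward_fn(data, return_dict=False):\n    # One (zero) reward value per response token in the batch; shape is illustrative.\n    reward_tensor = torch.zeros(data.batch_size[0], data[\"responses\"].shape[-1])\n    if return_dict:\n        return {\"reward_tensor\": reward_tensor, \"reward_extra_info\": {\"source\": \"dummy\"}}\n    return reward_tensor\n\n\nif __name__ == \"__main__\":\n    batch = TensorDict({\"responses\": torch.zeros(2, 4, dtype=torch.long)}, batch_size=[2])\n    reward_tensor, extra_info = compute_reward(batch, dummy_reward_fn)\n    print(reward_tensor.shape, extra_info)  # torch.Size([2, 4]) {'source': 'dummy'}\n"
  },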
  {
    "path": "siirl/execution/scheduler/task_scheduler.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport collections\nimport re\nfrom typing import List, Dict, Callable, Optional, Tuple, Set\nfrom loguru import logger\nfrom siirl.execution.dag.node import Node, NodeType\nfrom siirl.execution.dag.task_graph import TaskGraph\nfrom siirl.execution.dag.task_loader import discover_and_split_parallel_paths\n\n\ndef _parse_model_params_string(params_value: any) -> float:\n    \"\"\"\n    Parses a model parameter value, which can be a number or a string with units (M/B/K).\n    E.g.: \"70B\" -> 70 * 10^9, \"500M\" -> 500 * 10^6, \"100K\" -> 100 * 10^3.\n    \"\"\"\n    if isinstance(params_value, (int, float)):\n        return float(params_value)\n    if isinstance(params_value, str):\n        params_value_upper = params_value.upper()\n        # Regex to extract the numerical part (integer or float)\n        num_part_match = re.match(r\"^\\d+(\\.\\d+)?\", params_value_upper)\n        if not num_part_match:\n            logger.warning(f\"Could not parse numerical part from string '{params_value}'. Defaulting to 0.\")\n            return 0.0\n\n        num = float(num_part_match.group(0))\n\n        if params_value_upper.endswith(\"B\"):\n            return num * 1e9  # Billion\n        elif params_value_upper.endswith(\"M\"):\n            return num * 1e6  # Million\n        elif params_value_upper.endswith(\"K\"):\n            return num * 1e3  # Thousand\n        else:\n            # If no explicit unit, but the string can be converted to a float (e.g., \"1000000\")\n            try:\n                return float(params_value)  # Try to convert the whole string if no unit\n            except ValueError:\n                logger.warning(f\"Unrecognized model parameter unit or format '{params_value}'. Defaulting to 0.\")\n                return 0.0\n    logger.warning(f\"Unknown model parameter type '{type(params_value)}' value '{params_value}'. 
Defaulting to 0.\")\n    return 0.0\n\n\ndef estimate_graph_model_params(task_graph: TaskGraph) -> float:\n    \"\"\"\n    Estimates the 'size' of a task graph, typically based on the sum of model parameters\n    for MODEL_TRAIN or MODEL_INFERENCE nodes.\n    If no such nodes exist, it defaults to the number of nodes.\n    If model nodes exist but have no 'model_params' in config, it defaults to 0.\n    'model_params' can be a number, or a string like \"70B\" (70 billion) or \"500M\" (500 million).\n\n    Args:\n        task_graph (TaskGraph): The task graph to estimate.\n\n    Returns:\n        float: The estimated size of the graph.\n    \"\"\"\n    total_params: float = 0.0\n    if not task_graph or not task_graph.nodes:\n        return 0.0\n\n    has_model_nodes = False\n    has_positive_params_in_model_nodes = False\n    for node in task_graph.nodes.values():\n        if node.node_type in [NodeType.MODEL_TRAIN, NodeType.MODEL_INFERENCE]:\n            has_model_nodes = True\n            # Get 'model_params' from node config, default to 0.0 if not present\n            raw_params_value = node.config.get(\"model_params\", 0.0)\n            node_params = _parse_model_params_string(raw_params_value)  # Use helper to parse params\n\n            if node_params > 0:\n                has_positive_params_in_model_nodes = True\n            total_params += node_params\n\n    # If there are no model-related nodes, use the number of nodes as a proxy for size.\n    if not has_model_nodes:\n        return float(len(task_graph.nodes))\n    # If there are model nodes but none have positive parameter counts,\n    # this implies they are model-related but their sizes are unknown or zero.\n    # Returning 0.0 signals that it's a model graph of unknown/zero parameter size.\n    if has_model_nodes and not has_positive_params_in_model_nodes:\n        return 0.0\n\n    return total_params\n\n\nclass TaskScheduler:\n    \"\"\"\n    Schedules tasks (represented as TaskGraphs) and assigns them to a distributed set of workers,\n    each potentially having multiple GPUs. 
It aims to distribute tasks efficiently based on\n    their size and cohesion (parts of the same task on the same node).\n    \"\"\"\n\n    def __init__(self, num_physical_nodes: int, gpus_per_node: int):\n        \"\"\"\n        Initializes the TaskScheduler.\n\n        Args:\n            num_physical_nodes (int): The total number of physical compute nodes (machines).\n            gpus_per_node (int): The number of GPUs available on each physical node.\n\n        Raises:\n            ValueError: If num_physical_nodes or gpus_per_node is not positive.\n        \"\"\"\n        if num_physical_nodes <= 0:\n            raise ValueError(\"Number of physical nodes must be positive.\")\n        if gpus_per_node <= 0:\n            raise ValueError(\"GPUs per node must be positive.\")\n\n        self.num_physical_nodes: int = num_physical_nodes\n        self.gpus_per_node: int = gpus_per_node\n        self.num_workers: int = num_physical_nodes * gpus_per_node  # Total available worker slots (GPUs)\n\n        # State variables, reset for each scheduling call\n        self.worker_to_graph_assignment: Dict[int, Optional[TaskGraph]] = {}  # Maps worker rank to assigned TaskGraph\n        self.node_active_worker_count: Dict[int, int] = collections.defaultdict(int)  # Physical node index -> count of active workers\n        self.node_free_gpus: Dict[int, List[int]] = collections.defaultdict(list)  # Physical node index -> list of free worker ranks (GPUs) on that node\n        self._reset_scheduler_state()\n\n    def _reset_scheduler_state(self):\n        \"\"\"Resets the internal state of the scheduler, typically before a new scheduling pass.\"\"\"\n        self.worker_to_graph_assignment = {r: None for r in range(self.num_workers)}\n        self.node_active_worker_count = collections.defaultdict(int)\n        self.node_free_gpus = collections.defaultdict(list)\n        # Initialize free GPUs for each physical node\n        for worker_rank in range(self.num_workers):\n            physical_node_idx = worker_rank // self.gpus_per_node  # Determine which physical node this worker (GPU) belongs to\n            self.node_free_gpus[physical_node_idx].append(worker_rank)\n\n        for node_idx in self.node_free_gpus:\n            self.node_free_gpus[node_idx].sort()  # Keep free GPU ranks sorted, useful for consistent tie-breaking\n\n    def _get_original_graph_id(self, task_graph: TaskGraph) -> str:\n        \"\"\"\n        Extracts the original graph ID from a potentially modified (e.g., split) TaskGraph ID.\n        This is primarily used for logging or tracking purposes now, as inter-subgraph affinity\n        is no longer a direct scheduling factor.\n\n        Args:\n            task_graph (TaskGraph): The TaskGraph whose original ID is to be determined.\n\n        Returns:\n            str: The original graph ID.\n        \"\"\"\n        # Example naming from discover_and_split_parallel_paths: f\"{base_id_for_final_naming}_split_{idx}\"\n        # This function assumes such a convention.\n        parts = task_graph.graph_id.split(\"_split_\")\n        if len(parts) > 1:  # Covers cases like \"origID_final_1\", \"origID_sub_1_split_2\" etc.\n            base_id_candidate = parts[0]\n            return base_id_candidate\n        return task_graph.graph_id  # Return full ID if pattern doesn't match\n\n    def _apportion_workers_to_tasks(self, task_graphs_with_estimated_sizes: List[Tuple[TaskGraph, float]], num_total_workers_to_assign: int, apportion_strategy: str) -> Dict[str, int]:\n        \"\"\"\n        
Distributes a total number of workers among a list of tasks based on a chosen strategy.\n        Ensures that if num_total_workers_to_assign >= num_tasks, each task gets at least one worker,\n        with remaining workers distributed according to the strategy.\n\n        Args:\n            task_graphs_with_estimated_sizes (List[Tuple[TaskGraph, float]]): A list of tuples,\n                each containing a TaskGraph and its estimated size. These are the tasks that *will* run.\n            num_total_workers_to_assign (int): The total number of workers to distribute among these tasks.\n            apportion_strategy (str): The strategy for apportioning workers ('even' or 'param_aware').\n\n        Returns:\n            Dict[str, int]: A dictionary mapping task_graph.graph_id to the number of workers assigned.\n        \"\"\"\n        num_tasks_to_run = len(task_graphs_with_estimated_sizes)\n        # Initialize apportionment: graph_id -> num_workers\n        apportionment: Dict[str, int] = {tg.graph_id: 0 for tg, _ in task_graphs_with_estimated_sizes}\n\n        if num_tasks_to_run == 0 or num_total_workers_to_assign == 0:\n            return apportionment  # No tasks or no workers, nothing to apportion.\n\n        workers_assigned_so_far = 0\n        # Step 1: Assign one worker to each task that is set to run, if enough workers are available.\n        # This ensures every task selected to run gets at least minimal resources.\n        if num_total_workers_to_assign >= num_tasks_to_run:\n            for graph, _ in task_graphs_with_estimated_sizes:\n                apportionment[graph.graph_id] = 1\n            workers_assigned_so_far = num_tasks_to_run\n\n        # Calculate remaining workers to be distributed after the initial assignment (if any).\n        remaining_workers_to_distribute = num_total_workers_to_assign - workers_assigned_so_far\n\n        # Step 2: Distribute any remaining workers based on the chosen strategy.\n        if remaining_workers_to_distribute > 0:\n            if apportion_strategy == \"even\":\n                # Distribute remaining workers as evenly as possible on top of the initial ones.\n                if num_tasks_to_run == 0:  # Avoid division by zero if no tasks to run (already handled, but defensive)\n                    return apportionment\n                base_additional_workers = remaining_workers_to_distribute // num_tasks_to_run\n                remainder_additional_workers = remaining_workers_to_distribute % num_tasks_to_run\n                # Sort by graph ID for deterministic distribution of remainder, though original order is usually fine.\n                sorted_graphs_for_remainder = sorted(task_graphs_with_estimated_sizes, key=lambda x: x[0].graph_id)\n                for i, (graph, _) in enumerate(sorted_graphs_for_remainder):\n                    apportionment[graph.graph_id] += base_additional_workers + (1 if i < remainder_additional_workers else 0)\n\n            elif apportion_strategy == \"param_aware\":\n                # Calculate total size for weighting, excluding tasks with zero or negative estimated size.\n                total_size_for_weights = sum(size for _, size in task_graphs_with_estimated_sizes if size > 0)\n\n                if total_size_for_weights > 0 and num_tasks_to_run > 0:  # Ensure num_tasks_to_run > 0 for modulo\n                    # Greedily assign remaining workers one by one to tasks, prioritizing larger tasks.\n                    # Tasks are sorted by size (desc) to pick recipients for extra workers.\n      
              # This ensures that larger tasks are favored when distributing the remainder.\n                    sorted_tasks_by_size_desc = sorted(task_graphs_with_estimated_sizes, key=lambda x: x[1], reverse=True)\n\n                    temp_rem_workers = remaining_workers_to_distribute\n                    worker_idx_counter = 0\n                    # Distribute remaining workers one by one, cycling through tasks sorted by size.\n                    while temp_rem_workers > 0:\n                        # Cycle through tasks (largest first) to give them additional workers.\n                        task_to_give_worker_tuple = sorted_tasks_by_size_desc[worker_idx_counter % num_tasks_to_run]\n                        apportionment[task_to_give_worker_tuple[0].graph_id] += 1\n                        temp_rem_workers -= 1\n                        worker_idx_counter += 1\n                else:\n                    # If all tasks have zero/negative size or no tasks, fall back to even distribution for the remainder.\n                    if num_tasks_to_run == 0:\n                        return apportionment  # Avoid division by zero\n                    base_additional_workers = remaining_workers_to_distribute // num_tasks_to_run\n                    remainder_additional_workers = remaining_workers_to_distribute % num_tasks_to_run\n                    sorted_graphs_for_fallback = sorted(task_graphs_with_estimated_sizes, key=lambda x: x[0].graph_id)\n                    for i, (graph, _) in enumerate(sorted_graphs_for_fallback):\n                        apportionment[graph.graph_id] += base_additional_workers + (1 if i < remainder_additional_workers else 0)\n            else:\n                raise ValueError(f\"Unknown apportionment strategy: {apportion_strategy}\")\n\n        # Sanity check: Ensure the total number of assigned workers matches the target.\n        current_sum_workers = sum(apportionment.values())\n        if current_sum_workers != num_total_workers_to_assign:\n            # This indicates a potential logic error in apportionment.\n            logger.error(f\"Apportionment sum mismatch. Expected {num_total_workers_to_assign}, got {current_sum_workers}. 
Apportionment map: {apportionment}\")\n            return {}\n            # Corrective action might be needed here in a production system.\n        return apportionment\n\n    def schedule_and_assign_tasks(\n        self, original_task_graphs: List[TaskGraph], size_estimator: Callable[[TaskGraph], float] = estimate_graph_model_params, apportion_strategy: str = \"param_aware\", consider_node_cohesion: bool = True, consider_node_load: bool = True, consider_rank_preference: bool = True\n    ) -> Dict[int, Optional[TaskGraph]]:\n        \"\"\"\n        Schedules a list of original TaskGraphs by first splitting them into irreducible subgraphes,\n        then assigning these subgraphes to workers (GPUs) across physical nodes.\n\n        The method raises a ValueError if not all schedulable subgraphes can be assigned at least one worker.\n\n        Args:\n            original_task_graphs (List[TaskGraph]): A list of original TaskGraph objects to be scheduled.\n            size_estimator (Callable[[TaskGraph], float]): A function that estimates the 'size' of a TaskGraph.\n            apportion_strategy (str): Strategy for distributing workers among tasks.\n            consider_node_cohesion (bool): If True, tries to schedule workers for the same\n                                           irreducible subgraph onto the same physical compute node.\n            consider_node_load (bool): If True (default), placement will prefer physical nodes with lower current load.\n            consider_rank_preference (bool): If True (default), placement will prefer lower-ranked GPUs as a tie-breaker.\n\n\n        Returns:\n            Dict[int, Optional[TaskGraph]]: A dictionary mapping worker rank to its assigned TaskGraph.\n\n        Raises:\n            ValueError: If the number of schedulable subgraphes exceeds the number of available workers,\n                        making it impossible to assign at least one worker to each.\n        \"\"\"\n        self._reset_scheduler_state()  # Initialize scheduler state for a new run.\n\n        if not original_task_graphs:\n            logger.info(\"No original TaskGraphs provided for scheduling. All workers will be idle.\")\n            return self.worker_to_graph_assignment  # Return empty assignment if no tasks.\n\n        # Step 1: Split original task graphs into irreducible subgraphes.\n        # Irreducible subgraphes are the actual units of work that will be scheduled.\n        all_irreducible_subgraphes: List[TaskGraph] = []\n        for i, original_graph in enumerate(original_task_graphs):\n            if not original_graph or not original_graph.nodes:\n                logger.warning(f\"Original graph at index {i} (ID: {original_graph.graph_id if original_graph else 'N/A'}) is empty. Skipping.\")\n                continue\n            # discover_and_split_parallel_paths breaks down complex graphs.\n            subgraphes = discover_and_split_parallel_paths(original_graph)\n            all_irreducible_subgraphes.extend(subgraphes)\n\n        if not all_irreducible_subgraphes:\n            logger.info(\"No schedulable irreducible subgraphes were derived. 
All workers will be idle.\")\n            return self.worker_to_graph_assignment  # Return empty if splitting results in no subgraphes.\n\n        # Step 2: Estimate sizes of irreducible subgraphes and sort them.\n        # Sorting by size (descending) helps in prioritizing larger tasks if not all can run,\n        # or in the 'param_aware' apportionment strategy.\n        subgraphes_with_sizes_sorted: List[Tuple[TaskGraph, float]] = sorted(\n            [(sg, size_estimator(sg)) for sg in all_irreducible_subgraphes],\n            key=lambda x: x[1],  # Sort by estimated size.\n            reverse=True,  # Largest tasks first.\n        )\n\n        num_schedulable_subgraphes = len(subgraphes_with_sizes_sorted)\n        workers_per_task_map: Dict[str, int]  # Map: subgraph_id -> number of workers assigned to it.\n\n        # Step 3: Determine how many workers each subgraph gets.\n        # Crucially, if not all tasks can be assigned at least one worker, raise an error.\n        if num_schedulable_subgraphes > self.num_workers:\n            raise ValueError(f\"Cannot assign all tasks. Number of schedulable subgraphes ({num_schedulable_subgraphes}) exceeds the total number of available workers ({self.num_workers}). Please provide more workers or reduce the number of tasks/subgraphes.\")\n        else:\n            # All schedulable subgraphes will run.\n            # Apportion all available workers (self.num_workers) among these subgraphes.\n            tasks_to_run_with_sizes = subgraphes_with_sizes_sorted  # All of them are considered for running.\n            workers_per_task_map = self._apportion_workers_to_tasks(\n                tasks_to_run_with_sizes,  # All schedulable subgraphes.\n                self.num_workers,  # Total workers to distribute among them.\n                apportion_strategy,\n            )\n\n        # Filter to get only tasks that were actually assigned workers (should be all in this logic path).\n        tasks_actually_running_with_sizes = [(tg, size) for tg, size in tasks_to_run_with_sizes if workers_per_task_map.get(tg.graph_id, 0) > 0]\n        # Order tasks for placement: prioritize tasks needing more workers, then by size.\n        # This can influence placement if certain nodes become full.\n        tasks_to_place_ordered = sorted(\n            tasks_actually_running_with_sizes,\n            key=lambda x: (workers_per_task_map.get(x[0].graph_id, 0), x[1]),\n            # Sort by num_workers_for_task then by size.\n            reverse=True,  # Tasks needing more workers/larger tasks first.\n        )\n\n        # Step 4: Place each worker for each scheduled subgraph.\n        for task_graph_to_place, _ in tasks_to_place_ordered:  # The estimated size is not directly used in this loop.\n            subgraph_id = task_graph_to_place.graph_id\n            num_workers_for_this_subgraph = workers_per_task_map.get(subgraph_id, 0)\n\n            if num_workers_for_this_subgraph == 0:\n                # This should not happen if the logic above correctly assigns workers.\n                logger.warning(f\"Subgraph {subgraph_id} was allocated 0 workers. 
Skipping placement.\")\n                continue\n\n            # Keep track of workers assigned to *this specific subgraph instance* for cohesion calculation.\n            workers_assigned_to_current_subgraph_instance: List[int] = []\n\n            # Assign each of the required workers for the current subgraph.\n            for worker_slot_index in range(num_workers_for_this_subgraph):\n                best_worker_rank_for_slot: int = -1\n                # Scoring tuple for placement: (cohesion_score, node_load_score, rank_preference_score)\n                # Higher scores are better. node_load and rank_preference are negative (lower is better).\n                best_placement_score: Tuple[float, float, float] = (float(\"-inf\"), float(\"-inf\"), float(\"-inf\"))\n\n                # Determine cohesion targets: physical nodes already running other workers for THIS specific subgraph instance.\n                intra_task_cohesion_target_nodes: Set[int] = set()\n                if consider_node_cohesion and workers_assigned_to_current_subgraph_instance:\n                    for r_assigned_to_this_task in workers_assigned_to_current_subgraph_instance:\n                        intra_task_cohesion_target_nodes.add(r_assigned_to_this_task // self.gpus_per_node)\n\n                # Iterate over all physical compute nodes to find the best free GPU for the current slot.\n                for physical_node_idx in range(self.num_physical_nodes):\n                    if not self.node_free_gpus[physical_node_idx]:\n                        continue  # No free GPUs on this physical node.\n\n                    # Consider the first available (e.g., lowest rank) GPU on this node.\n                    # Sorting of node_free_gpus[node_idx] in _reset_scheduler_state ensures this is deterministic.\n                    potential_worker_rank = self.node_free_gpus[physical_node_idx][0]\n\n                    # --- Calculate scores for this potential placement ---\n                    # Cohesion score:\n                    cohesion_score = 0.0\n                    if consider_node_cohesion:\n                        if not workers_assigned_to_current_subgraph_instance:\n                            cohesion_score = 1.0  # Any node is fine for the first worker from cohesion perspective.\n                        elif physical_node_idx in intra_task_cohesion_target_nodes:\n                            cohesion_score = 1.0  # Placing with its peers.\n\n                    # Node load score:\n                    node_load_score_component = 0.0  # Default to 0 if not considering load.\n                    if consider_node_load:\n                        node_load_score_component = -float(self.node_active_worker_count[physical_node_idx])  # Negative: lower load is better.\n\n                    # Rank preference score:\n                    rank_score_component = 0.0  # Default to 0 if not considering rank preference.\n                    if consider_rank_preference:\n                        rank_score_component = -float(potential_worker_rank)  # Negative: lower rank is better (converted to float for type consistency in tuple).\n\n                    current_placement_score: Tuple[float, float, float] = (cohesion_score, node_load_score_component, rank_score_component)\n\n                    # If current placement is better than the best found so far, update.\n                    if current_placement_score > best_placement_score:\n                        best_placement_score = current_placement_score\n                        
best_worker_rank_for_slot = potential_worker_rank\n\n                # Assign the subgraph to the best found worker slot.\n                if best_worker_rank_for_slot != -1:\n                    chosen_physical_node_idx = best_worker_rank_for_slot // self.gpus_per_node\n                    self.worker_to_graph_assignment[best_worker_rank_for_slot] = task_graph_to_place\n                    self.node_active_worker_count[chosen_physical_node_idx] += 1  # Increment active worker count for the chosen node.\n                    self.node_free_gpus[chosen_physical_node_idx].remove(best_worker_rank_for_slot)  # Mark GPU as used.\n                    workers_assigned_to_current_subgraph_instance.append(best_worker_rank_for_slot)\n                else:\n                    # This error implies a logic flaw if workers were apportioned but cannot be placed.\n                    original_id_for_logging = self._get_original_graph_id(task_graph_to_place)  # For better logging.\n                    logger.error(\n                        f\"Could not find any free worker to place for subgraph {subgraph_id} (Original: {original_id_for_logging}). \"\n                        f\"Worker slot {worker_slot_index + 1}/{num_workers_for_this_subgraph}. \"\n                        f\"Already placed for this subgraph: {workers_assigned_to_current_subgraph_instance}. \"\n                        f\"Total workers assigned so far: {sum(1 for w in self.worker_to_graph_assignment.values() if w is not None)}/{self.num_workers}. \"\n                        f\"Investigate scheduler logic or available resources.\"\n                    )\n                    break  # Stop trying to place workers for this subgraph if one fails critically.\n\n        # Final check: Ensure all workers are utilized if there were tasks.\n        final_assigned_worker_count = sum(1 for w_val in self.worker_to_graph_assignment.values() if w_val is not None)\n        if final_assigned_worker_count != self.num_workers and len(all_irreducible_subgraphes) > 0:\n            logger.warning(f\"Post-scheduling, {self.num_workers - final_assigned_worker_count} workers are unexpectedly idle despite having {len(all_irreducible_subgraphes)} schedulable subgraphes. Total workers assigned: {final_assigned_worker_count}/{self.num_workers}.\")\n\n        return self.worker_to_graph_assignment\n\n    def get_unique_assigned_task_graphs(self) -> Dict[str, TaskGraph]:\n        \"\"\"\n        Returns a list of unique TaskGraph objects that have been assigned to workers\n        as a result of the last scheduling pass. This list contains the irreducible\n        subgraphs that were actually scheduled.\n\n        Returns:\n            List[TaskGraph]: A list of unique TaskGraph objects. 
Returns an empty list\n                             if no tasks were scheduled or if the scheduler hasn't run.\n        \"\"\"\n        if not self.worker_to_graph_assignment:\n            return {}\n\n        unique_graphs_map: Dict[str, TaskGraph] = {}\n        for task_graph in self.worker_to_graph_assignment.values():\n            if task_graph:  # Filter out None values (idle workers)\n                # The graph_id of subgraphs generated by discover_and_split_parallel_paths\n                # is unique, making it a reliable key for identifying unique TaskGraph instances.\n                unique_graphs_map[task_graph.graph_id] = task_graph\n\n        return unique_graphs_map\n\n\ndef _format_ranks_for_logging(ranks: Optional[List[int]], detailed_rank_printing: bool, threshold: int = 10) -> str:\n    \"\"\"\n    Formats a list of ranks for logging.\n    If detailed_rank_printing is True, or if the number of ranks is below threshold,\n    it prints all ranks. Otherwise, it prints a range and count.\n    \"\"\"\n    if not ranks:\n        return \"N/A\"\n    # Ensure ranks are sorted for consistent output, especially for range.\n    # Make a copy before sorting if the original list should not be modified,\n    # though in this context, the lists are usually temporary.\n    sorted_ranks = sorted(list(set(ranks)))  # Remove duplicates and sort\n\n    if detailed_rank_printing or len(sorted_ranks) <= threshold:\n        return str(sorted_ranks)\n    else:\n        if not sorted_ranks:  # Should not happen if ranks is not None and not empty\n            return \"[] (Empty after sort/unique)\"\n        return f\"[{min(sorted_ranks)} ... {max(sorted_ranks)}] (Count: {len(sorted_ranks)})\"\n\n\ndef log_schedule_assignments(\n    assignments: Dict[int, Optional[\"TaskGraph\"]],\n    num_total_workers: int,\n    detailed_rank_printing: bool = False,  # New parameter\n    rank_print_threshold: int = 10,  # New parameter\n) -> None:\n    \"\"\"\n    Clearly logs the results of task scheduling and assignment using loguru,\n    with an option for concise rank printing.\n\n    Args:\n        assignments (Dict[int, Optional[TaskGraph]]):\n            A dictionary where keys are worker ranks (int) and values are\n            the assigned TaskGraph object or None.\n        num_total_workers (int):\n            The total number of available workers in the system.\n        detailed_rank_printing (bool): If True, prints all ranks.\n                                     Otherwise, uses range for large lists.\n        rank_print_threshold (int): The threshold above which ranks are printed as a range.\n    \"\"\"\n    if not isinstance(assignments, dict):\n        # Use logger.error for actual errors, info for informational messages.\n        logger.error(\"Input for printing schedule assignments must be a dictionary.\")\n        return\n\n    log_messages: List[str] = [\"\\n\\n--- Task Schedule Assignment Results ---\", f\"Detailed Rank Printing: {detailed_rank_printing}, Threshold: {rank_print_threshold}\"]  # Collect all log parts here\n\n    # 1. 
Group workers by TaskGraph\n    task_to_workers_map: Dict[str, List[int]] = collections.defaultdict(list)\n    idle_workers: List[int] = []\n\n    # Iterate up to num_total_workers to correctly identify all idle workers\n    all_assigned_ranks = set()\n    for worker_rank, assigned_graph in assignments.items():\n        if worker_rank < num_total_workers:  # Ensure we only consider valid worker ranks\n            all_assigned_ranks.add(worker_rank)\n            if assigned_graph and hasattr(assigned_graph, \"graph_id\"):\n                task_to_workers_map[assigned_graph.graph_id].append(worker_rank)\n            # else: # This rank is assigned None or an invalid object, consider idle if not in a task\n            #    pass # Handled by the loop below\n\n    for r in range(num_total_workers):\n        if r not in all_assigned_ranks or assignments.get(r) is None:  # Check if rank is truly idle\n            is_assigned_to_task = False\n            for _, assigned_ranks_for_task in task_to_workers_map.items():\n                if r in assigned_ranks_for_task:\n                    is_assigned_to_task = True\n                    break\n            if not is_assigned_to_task:\n                idle_workers.append(r)\n\n    for graph_id_key in task_to_workers_map:  # Sort ranks within each task's list\n        task_to_workers_map[graph_id_key].sort()\n    idle_workers.sort()  # Sort idle worker list\n\n    # 2. Prepare summary information\n    num_assigned_workers = 0\n    for worker_list in task_to_workers_map.values():\n        num_assigned_workers += len(worker_list)\n    # Correctly calculate idle workers based on total workers and those assigned tasks\n    num_idle_workers = num_total_workers - num_assigned_workers\n\n    num_scheduled_tasks = len(task_to_workers_map)\n\n    log_messages.append(f\"Total Workers: {num_total_workers}\")\n    log_messages.append(f\"Workers with Assigned Tasks: {num_assigned_workers}\")\n    log_messages.append(f\"Idle Workers: {num_idle_workers} (Derived from total - assigned)\")\n    log_messages.append(f\"Number of Scheduled TaskGraphs (Subgraphs): {num_scheduled_tasks}\")\n\n    # 3. Prepare detailed assignment for each TaskGraph\n    if task_to_workers_map:\n        log_messages.append(\"\\nDetailed Assignments:\")\n        for graph_id, worker_ranks in sorted(task_to_workers_map.items()):\n            ranks_str = _format_ranks_for_logging(worker_ranks, detailed_rank_printing, rank_print_threshold)\n            log_messages.append(f\"  TaskGraph (Subgraph ID): {graph_id}\")\n            log_messages.append(f\"    Assigned Worker Count: {len(worker_ranks)}\")\n            log_messages.append(f\"    Worker Ranks: {ranks_str}\")\n    else:\n        log_messages.append(\"\\nNo TaskGraphs were assigned to any workers.\")\n\n    # 4. 
Prepare idle workers information\n    # Use the derived idle_workers list for more accuracy if assignments dict might be sparse\n    actual_idle_worker_ranks = [r for r in range(num_total_workers) if all(r not in wr for wr in task_to_workers_map.values())]\n    actual_idle_worker_ranks.sort()\n\n    if actual_idle_worker_ranks:\n        ranks_str = _format_ranks_for_logging(actual_idle_worker_ranks, detailed_rank_printing, rank_print_threshold)\n        log_messages.append(\"\\nIdle Worker Ranks:\")\n        log_messages.append(f\"  Ranks: {ranks_str} (Count: {len(actual_idle_worker_ranks)})\")\n    elif num_total_workers > 0:  # No idle workers, and there are workers in the system\n        log_messages.append(\"\\nNo idle workers.\")\n\n    if num_total_workers == 0:\n        log_messages.append(\"\\nSystem has no workers.\")\n\n    log_messages.append(\"--- End of Assignment Results ---\\n\\n\")\n\n    # Log all messages as a single multi-line info block\n    logger.debug(\"\\n\".join(log_messages))\n\n\n# --- Example Usage ---\nif __name__ == \"__main__\":\n    # Setup for creating dummy TaskGraph objects for testing\n    def create_dummy_graph(graph_id: str, num_nodes: int, model_params: any = 0.0, dependencies_map: Optional[Dict[int, List[int]]] = None) -> TaskGraph:  # model_params type changed to any\n        \"\"\"Creates a TaskGraph for testing.\"\"\"\n        dummy_graph = TaskGraph(graph_id=graph_id)\n        nodes_to_add = []\n        for i in range(num_nodes):\n            node_type_val = NodeType.COMPUTE\n            node_config_val = {}\n            current_node_id = f\"{graph_id}_n{i}\"\n            node_deps: List[str] = []\n\n            if dependencies_map and i in dependencies_map:\n                node_deps = [f\"{graph_id}_n{dep_idx}\" for dep_idx in dependencies_map[i]]\n\n            # Put model_params directly into config, let estimate_graph_model_params handle parsing\n            if model_params != 0.0 and i == 0:  # For simplicity, assign params to the first node\n                node_type_val = NodeType.MODEL_TRAIN\n                node_config_val = {\"model_params\": model_params}  # Use passed model_params directly\n\n            nodes_to_add.append(Node(node_id=current_node_id, node_type=node_type_val, config=node_config_val, dependencies=node_deps))\n\n        if nodes_to_add:\n            dummy_graph.add_nodes(nodes_to_add)\n            dummy_graph.build_adjacency_lists()  # Important for validation and splitting\n            is_valid, msg = dummy_graph.validate_graph()\n            if not is_valid:\n                logger.warning(f\"Created dummy graph {graph_id} is invalid: {msg}\")\n        return dummy_graph\n\n    # --- Scheduler Configuration ---\n    num_physical_compute_nodes = 2  # e.g., 2 machines\n    gpus_per_compute_node = 4  # e.g., 4 GPUs per machine\n    scheduler = TaskScheduler(num_physical_nodes=num_physical_compute_nodes, gpus_per_node=gpus_per_compute_node)\n    # Total workers = 2 * 4 = 8\n\n    logger.info(f\"--- Initialized Scheduler: {scheduler.num_physical_nodes} Physical Nodes, {scheduler.gpus_per_node} GPUs/Node, Total Workers: {scheduler.num_workers} ---\")\n\n    # --- Scenario 1: Fewer original tasks than workers ---\n    logger.info(\"\\n--- Scenario 1: 3 Original Tasks (irreducible), 8 Workers ---\")\n    original_tasks_scen1 = [\n        create_dummy_graph(\"Weather_Sys\", num_nodes=3, model_params=\"600B\"),\n        create_dummy_graph(\"NLP_Sys\", num_nodes=2, model_params=\"300M\"),\n        
create_dummy_graph(\"Vision_Sys\", num_nodes=1, model_params=100.0),\n    ]\n    configs_to_test_scene1 = [\n        {\"name\": \"Apportion:Even, Cohesion:Y, Load:Y, Rank:Y\", \"apportion\": \"even\", \"cohesion\": True, \"load\": True, \"rank\": True},\n        {\"name\": \"Apportion:Param, Cohesion:Y, Load:Y, Rank:Y\", \"apportion\": \"param_aware\", \"cohesion\": True, \"load\": True, \"rank\": True},\n        {\"name\": \"Apportion:Param, Cohesion:N, Load:Y, Rank:Y\", \"apportion\": \"param_aware\", \"cohesion\": False, \"load\": True, \"rank\": True},\n        {\"name\": \"Apportion:Param, Cohesion:Y, Load:N, Rank:Y\", \"apportion\": \"param_aware\", \"cohesion\": True, \"load\": False, \"rank\": True},\n        {\"name\": \"Apportion:Param, Cohesion:Y, Load:Y, Rank:N\", \"apportion\": \"param_aware\", \"cohesion\": True, \"load\": True, \"rank\": False},\n    ]\n    for cfg in configs_to_test_scene1:\n        logger.info(f\"\\n-- Config: {cfg['name']} --\")\n        try:\n            assignments = scheduler.schedule_and_assign_tasks(original_tasks_scen1, apportion_strategy=cfg[\"apportion\"], consider_node_cohesion=cfg[\"cohesion\"], consider_node_load=cfg[\"load\"], consider_rank_preference=cfg[\"rank\"])\n            workers_per_scheduled_task = collections.defaultdict(list)\n            for worker_rank, graph_obj in assignments.items():\n                if graph_obj:\n                    workers_per_scheduled_task[graph_obj.graph_id].append(worker_rank)\n\n            logger.info(\"  Workers per TaskGraph (Subgraph ID):\")\n            for task_id, worker_ranks in sorted(workers_per_scheduled_task.items()):\n                logger.info(f\"    {task_id}: {len(worker_ranks)} workers (Ranks: {sorted(worker_ranks)})\")\n            logger.info(f\"  Physical Node active worker counts: {dict(scheduler.node_active_worker_count)}\")\n            logger.info(f\"  Physical Node free GPUs: {{node_idx: gpu_ranks for node_idx, gpu_ranks in scheduler.node_free_gpus.items() if gpu_ranks}}\")\n        except ValueError as e:\n            logger.error(f\"  Error during scheduling: {e}\")\n\n    # --- Scenario 2: More original tasks than workers (tasks are simple/irreducible) ---\n    # This scenario should now raise a ValueError.\n    logger.info(\"\\n--- Scenario 2: 10 Original Tasks (irreducible), 8 Workers ---\")\n    original_tasks_scene2 = [create_dummy_graph(f\"T{i}_Job\", num_nodes=2, model_params=f\"{(10 - i) * 50}M\") for i in range(10)]\n\n    logger.info(f\"\\n-- Config: Apportion:Even, Cohesion:Y, Load:Y, Rank:Y (Expecting ValueError) --\")\n    try:\n        assignments_scene2 = scheduler.schedule_and_assign_tasks(original_tasks_scene2, apportion_strategy=\"even\", consider_node_cohesion=True, consider_node_load=True, consider_rank_preference=True)\n        workers_per_scheduled_task_scen2 = collections.defaultdict(list)\n        for worker_rank, graph_obj in assignments_scene2.items():\n            if graph_obj:\n                workers_per_scheduled_task_scen2[graph_obj.graph_id].append(worker_rank)\n\n        logger.info(\"  Workers per TaskGraph (Subgraph ID):\")  # Should not be reached\n        for task_id, worker_ranks in sorted(workers_per_scheduled_task_scen2.items()):\n            logger.info(f\"    {task_id}: {len(worker_ranks)} workers (Ranks: {sorted(worker_ranks)})\")\n        logger.info(f\"  Physical Node active worker counts: {dict(scheduler.node_active_worker_count)}\")\n    except ValueError as e:\n        logger.info(f\"  Successfully caught expected error: 
{e}\")\n\n    # --- Scenario 3: An original task that *can* be split ---\n    logger.info(\"\\n--- Scenario 3: 1 Original Task (splittable), 8 Workers ---\")\n    original_splittable_graph = TaskGraph(graph_id=\"ex1_reconverge_orig\")\n    original_splittable_graph.add_nodes(\n        [\n            Node(node_id=\"A\", node_type=NodeType.DATA_LOAD, config={\"model_params\": \"10M\"}),\n            Node(node_id=\"B\", node_type=NodeType.COMPUTE, dependencies=[\"A\"]),\n            Node(node_id=\"A1\", node_type=NodeType.DATA_LOAD, config={\"model_params\": \"10M\"}),\n            Node(node_id=\"B1\", node_type=NodeType.COMPUTE, dependencies=[\"A1\"]),\n            Node(node_id=\"C\", node_type=NodeType.COMPUTE, dependencies=[\"B\", \"B1\"], config={\"model_params\": \"50B\"}),\n            Node(node_id=\"D_ex1\", node_type=NodeType.COMPUTE, dependencies=[\"C\"]),\n            Node(node_id=\"E_ex1\", node_type=NodeType.MODEL_TRAIN, dependencies=[\"D_ex1\"], config={\"model_params\": \"100B\"}),\n        ]\n    )\n    original_splittable_graph.build_adjacency_lists()\n    is_valid, msg = original_splittable_graph.validate_graph()\n    if not is_valid:\n        logger.error(f\"Splittable graph is invalid: {msg}\")\n\n    original_tasks_scene3 = [original_splittable_graph]\n\n    logger.info(f\"\\n-- Config: Apportion:Param, Cohesion:Y, Load:Y, Rank:Y --\")\n    try:\n        assignments_scene3 = scheduler.schedule_and_assign_tasks(original_tasks_scene3, apportion_strategy=\"param_aware\", consider_node_cohesion=True, consider_node_load=True, consider_rank_preference=True)\n        workers_per_scheduled_task_scene3 = collections.defaultdict(list)\n        for worker_rank, graph_obj in assignments_scene3.items():\n            if graph_obj:\n                workers_per_scheduled_task_scene3[graph_obj.graph_id].append(worker_rank)\n\n        logger.info(\"  Workers per TaskGraph (Subgraph ID - after splitting ex1_reconverge_orig):\")\n        for task_id, worker_ranks in sorted(workers_per_scheduled_task_scene3.items()):\n            original_source_graph_for_task = TaskGraph(graph_id=task_id)\n            original_source = scheduler._get_original_graph_id(original_source_graph_for_task)\n            logger.info(f\"    {task_id} (from {original_source}): {len(worker_ranks)} workers (Ranks: {sorted(worker_ranks)})\")\n        logger.info(f\"  Physical Node active worker counts: {dict(scheduler.node_active_worker_count)}\")\n    except ValueError as e:\n        logger.error(f\"  Error during scheduling for splittable graph: {e}\")\n\n    # --- Scenario 4: No tasks ---\n    logger.info(\"\\n--- Scenario 4: 0 Original Tasks, 8 Workers ---\")\n    try:\n        assignments_scene4 = scheduler.schedule_and_assign_tasks([])\n        logger.info(f\"  Assignments (should be empty): {{k:v.graph_id if v else None for k,v in assignments_scen4.items() if v}}\")\n        logger.info(f\"  Physical Node active worker counts (should be all zeros): {dict(scheduler.node_active_worker_count)}\")\n    except ValueError as e:  # Should not happen for no tasks\n        logger.error(f\"  Error during scheduling for no tasks: {e}\")\n"
  },
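  {
    "path": "examples/apportion_workers_sketch.py",
    "content": "\"\"\"\nIllustrative sketch (hypothetical example file): a minimal, self-contained\nrendering of the worker-apportionment arithmetic described in\nTaskScheduler._apportion_workers_to_tasks. Task names and sizes are made up;\nuse TaskScheduler itself for real scheduling.\n\"\"\"\n\nfrom typing import Dict, List, Tuple\n\n\ndef apportion(tasks: List[Tuple[str, float]], total_workers: int, strategy: str = \"param_aware\") -> Dict[str, int]:\n    \"\"\"Distribute `total_workers` across (name, estimated_size) pairs.\"\"\"\n    counts = {name: 0 for name, _ in tasks}\n    if not tasks or total_workers == 0:\n        return counts\n    assigned = 0\n    # Step 1: one worker per task, if enough workers are available.\n    if total_workers >= len(tasks):\n        for name, _ in tasks:\n            counts[name] = 1\n        assigned = len(tasks)\n    remaining = total_workers - assigned\n    if remaining <= 0:\n        return counts\n    if strategy == \"even\":\n        # Spread the remainder as evenly as possible (deterministic by name).\n        base, extra = divmod(remaining, len(tasks))\n        for i, (name, _) in enumerate(sorted(tasks, key=lambda t: t[0])):\n            counts[name] += base + (1 if i < extra else 0)\n    elif strategy == \"param_aware\":\n        # Hand out the remainder one worker at a time, largest task first.\n        by_size = sorted(tasks, key=lambda t: t[1], reverse=True)\n        for i in range(remaining):\n            counts[by_size[i % len(by_size)][0]] += 1\n    else:\n        raise ValueError(f\"Unknown strategy: {strategy}\")\n    return counts\n\n\nif __name__ == \"__main__\":\n    demo_tasks = [(\"weather\", 600e9), (\"nlp\", 300e6), (\"vision\", 100.0)]\n    print(apportion(demo_tasks, 8, \"even\"))  # {'weather': 2, 'nlp': 3, 'vision': 3}\n    print(apportion(demo_tasks, 8, \"param_aware\"))  # {'weather': 3, 'nlp': 3, 'vision': 2}\n"
  },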
  {
    "path": "siirl/main_dag.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport time\nfrom pathlib import Path\nimport ray\n\nfrom siirl.execution.scheduler.enums import AdvantageEstimator, AlgorithmType, WorkflowType\nfrom siirl.execution.scheduler.graph_updater import display_node_config, update_task_graph_node_configs\nfrom siirl.execution.scheduler.launch import RayTrainer\nfrom siirl.execution.scheduler.process_group_manager import ProcessGroupManager, log_process_group_manager_details\nfrom siirl.execution.scheduler.task_scheduler import TaskScheduler, log_schedule_assignments\nfrom siirl.utils.logger.logging_utils import set_basic_config\nfrom siirl.params import SiiRLArguments, log_dict_formatted, parse_config\nfrom siirl.execution.dag import TaskGraph\nfrom siirl.execution.dag.builtin_pipelines import grpo_pipeline, ppo_pipeline, dapo_pipeline, embodied_srpo_pipeline\nfrom siirl.data_coordinator.data_buffer import init_data_coordinator\nfrom siirl.execution.metric_worker.metric_worker import MetricWorker\n\n\n# --- Constants ---\nRAY_RUNTIME_ENV_VARS = {\n    \"TOKENIZERS_PARALLELISM\": \"true\",\n    \"NCCL_DEBUG\": \"WARN\",\n    \"VLLM_LOGGING_LEVEL\": \"WARN\",\n}\n\n# The main runner is an orchestrator, not a heavy workload.\n# Assigning it a full CPU is often wasteful. A fractional CPU is more efficient.\nMAIN_RUNNER_CPU_RESERVATION = 5\n\n\ndef _maybe_prepare_embodied_manifest(siirl_args: SiiRLArguments) -> None:\n    \"\"\"\n    Generate LIBERO manifests for embodied runs (srpo-compatible behavior).\n    \"\"\"\n    from loguru import logger\n\n    is_embodied_model = (\n        hasattr(siirl_args.actor_rollout_ref, \"model\")\n        and hasattr(siirl_args.actor_rollout_ref.model, \"model_type\")\n        and siirl_args.actor_rollout_ref.model.model_type == \"embodied\"\n    )\n    if not is_embodied_model:\n        return\n\n    # Lazy import to avoid requiring libero for non-embodied workflows\n    from siirl.data_coordinator.dataloader.embodied_preprocess import prepare_libero_train_valid_datasets\n\n    embodied = siirl_args.actor_rollout_ref.embodied\n    if embodied is None:\n        return\n    env = embodied.env\n\n    if not siirl_args.data.train_files:\n        raise ValueError(\n            \"For embodied training, `data.train_files` must be specified. \"\n            \"It is used as the output path for the generated task manifest.\"\n        )\n\n    output_dir = os.path.dirname(os.path.expanduser(siirl_args.data.train_files[0]))\n    logger.info(\"Embodied AI run detected. 
Generating task manifest...\")\n    train_file, valid_file = prepare_libero_train_valid_datasets(\n        task_suite_name=env.env_name,\n        num_trials_per_task=env.num_trials_per_task,\n        dataset_dir=output_dir,\n    )\n    siirl_args.data.train_files = [str(train_file)]\n    siirl_args.data.val_files = [str(valid_file)]\n    logger.success(f\"Task manifests generated and configured at: {output_dir}\")\n\n\ndef load_pipeline(siirl_args: SiiRLArguments) -> TaskGraph:\n    \"\"\"\n    Load training pipeline using the Python-based Pipeline API.\n\n    This function supports two modes (in priority order):\n    1. Custom pipeline via dag.custom_pipeline_fn (user-specified Python function)\n    2. Built-in Python pipelines (grpo_pipeline, ppo_pipeline, dapo_pipeline)\n\n    Args:\n        siirl_args: Configuration arguments\n\n    Returns:\n        TaskGraph: Loaded and validated task graph\n\n    Raises:\n        ImportError: If custom pipeline function cannot be loaded\n        NotImplementedError: If no suitable pipeline is found\n    \"\"\"\n    # Import logger locally to avoid Ray serialization issues\n    from loguru import logger\n\n    # Mode 1: User-specified custom pipeline function\n    if hasattr(siirl_args.dag, 'custom_pipeline_fn') and siirl_args.dag.custom_pipeline_fn:\n        logger.info(f\"Loading custom pipeline: {siirl_args.dag.custom_pipeline_fn}\")\n\n        try:\n            # Parse function path: \"module.path:function_name\"\n            if \":\" not in siirl_args.dag.custom_pipeline_fn:\n                raise ValueError(\n                    f\"Invalid custom_pipeline_fn format: '{siirl_args.dag.custom_pipeline_fn}'. \"\n                    f\"Expected format: 'module.path:function_name'\"\n                )\n\n            module_path, func_name = siirl_args.dag.custom_pipeline_fn.rsplit(\":\", 1)\n\n            # Dynamically import the module and function\n            import importlib\n            module = importlib.import_module(module_path)\n            pipeline_fn = getattr(module, func_name)\n\n            if not callable(pipeline_fn):\n                raise ValueError(\n                    f\"'{siirl_args.dag.custom_pipeline_fn}' is not callable. 
\"\n                    f\"It should be a function that returns TaskGraph.\"\n                )\n\n            # Call the function to get TaskGraph\n            taskgraph = pipeline_fn()\n\n            if not isinstance(taskgraph, TaskGraph):\n                raise ValueError(\n                    f\"Custom pipeline function '{siirl_args.dag.custom_pipeline_fn}' \"\n                    f\"must return a TaskGraph object, got {type(taskgraph)}\"\n                )\n\n            logger.success(f\"Custom pipeline loaded successfully: {taskgraph.graph_id}\")\n            return taskgraph\n\n        except Exception as e:\n            logger.error(f\"Failed to load custom pipeline '{siirl_args.dag.custom_pipeline_fn}': {e}\")\n            raise\n\n    # Mode 2: Built-in Python pipelines (default)\n    logger.info(f\"Using built-in Python pipeline for algorithm: {siirl_args.algorithm.adv_estimator}\")\n\n    # Set CPGD-specific config\n    if siirl_args.algorithm.adv_estimator == AdvantageEstimator.CPGD:\n        siirl_args.actor_rollout_ref.actor.use_cpgd_loss = True\n\n    # Select appropriate built-in pipeline\n    # Check algorithm_name first for special variants like DAPO (which may have adv_estimator=grpo)\n    workflow = siirl_args.algorithm.workflow_type\n    if workflow == WorkflowType.EMBODIED:\n        # Embodied AI workflows\n        if siirl_args.algorithm.adv_estimator == AdvantageEstimator.GAE:\n            raise ValueError(\n                f\"Unsupported adv_estimator '{siirl_args.algorithm.adv_estimator}' for Embodied AI. \"\n                f\"Use 'gae' for PPO or 'grpo' for GRPO.\"\n            )\n        elif siirl_args.algorithm.adv_estimator == AdvantageEstimator.GRPO:\n            return embodied_srpo_pipeline()\n        else:\n            raise ValueError(\n                f\"Unsupported adv_estimator '{siirl_args.algorithm.adv_estimator}' for Embodied AI. \"\n                f\"Use 'gae' for PPO or 'grpo' for GRPO.\"\n            )\n    elif workflow == WorkflowType.DAPO:\n        return dapo_pipeline()\n    elif workflow == WorkflowType.DEFAULT:\n        if siirl_args.algorithm.adv_estimator == AdvantageEstimator.GAE:\n            return ppo_pipeline()\n        else:  # For GRPO, GSPO, etc.\n            return grpo_pipeline()  # CPGD uses GRPO structure\n\n    else:\n        raise ValueError(f\"Unknown workflow_type: '{workflow}'\")\n    \n\n\n\n@ray.remote(num_cpus=MAIN_RUNNER_CPU_RESERVATION)\nclass MainRunner:\n    \"\"\"\n    A Ray actor responsible for orchestrating the entire RL training workflow.\n\n    This actor handles loading configurations, scheduling task graphs, initializing\n    process groups, and launching the distributed Ray trainers. Isolating this\n    orchestration logic in a dedicated actor ensures the main process remains clean\n    and that the setup process is managed within the Ray cluster.\n    \"\"\"\n\n    def run(self, siirl_args: SiiRLArguments) -> None:\n        \"\"\"\n        Executes the main training workflow.\n\n        Args:\n            siirl_args: A SiiRLArguments object containing all parsed configurations.\n        \"\"\"\n        set_basic_config()\n        from loguru import logger\n\n        logger.info(\"MainRunner started. Beginning workflow setup...\")\n        start_time = time.time()\n\n        # 1. 
Init DataBuffer\n        logger.info(f\"Initializing DataCoordinator with {siirl_args.trainer.nnodes} distributed DataBuffers...\")\n        # In the new architecture, the number of buffers is typically the number of nodes.\n        # We pass force_local=False to enable distributed deployment.\n        data_coordinator_handle = init_data_coordinator(\n            num_buffers=siirl_args.trainer.nnodes, ppo_mini_batch_size = siirl_args.actor_rollout_ref.actor.ppo_mini_batch_size,\n            world_size=siirl_args.trainer.nnodes * siirl_args.trainer.n_gpus_per_node\n        )\n\n        # 2. Load and configure the workflow task graph (DAG)\n        logger.info(\"Loading training pipeline...\")\n        workerflow_taskgraph = load_pipeline(siirl_args)\n        update_task_graph_node_configs(workerflow_taskgraph, siirl_args)\n        display_node_config(workerflow_taskgraph)\n\n        # 3. Schedule the task graph across available resources\n        logger.info(\"Scheduling tasks across nodes and GPUs...\")\n        total_workers = siirl_args.trainer.nnodes * siirl_args.trainer.n_gpus_per_node\n        task_scheduler = TaskScheduler(siirl_args.trainer.nnodes, siirl_args.trainer.n_gpus_per_node)\n        rank_taskgraph_mapping = task_scheduler.schedule_and_assign_tasks([workerflow_taskgraph])\n        log_schedule_assignments(rank_taskgraph_mapping, total_workers)\n        unique_graphs_map = task_scheduler.get_unique_assigned_task_graphs()\n\n        # 4. Create and configure process groups for communication\n        logger.info(\"Initializing process groups for distributed communication...\")\n        process_group_manager = ProcessGroupManager(total_workers, rank_taskgraph_mapping)\n        log_process_group_manager_details(process_group_manager, log_level=\"debug\")\n        # set process_group info into env for inference_actor\n        inference_process_group = []\n        inference_groups = process_group_manager.node_type_process_group_mapping[\"MODEL_INFERENCE\"]\n        for group_name in inference_groups:\n            inference_process_group.append(process_group_manager.process_group_spec[group_name])\n        os.environ[\"DGA_PROCESS_GROUP\"] = str(inference_process_group)\n        \n        # 5. Create Metric Worker\n        metric_worker_handle = MetricWorker.remote()\n        # 6. Initialize the main trainer\n        logger.info(\"Initializing RayTrainer...\")\n        trainer = RayTrainer(\n            config=siirl_args,\n            process_group_manager=process_group_manager,\n            rank_taskgraph_mapping=rank_taskgraph_mapping,\n            unique_graphs_map=unique_graphs_map,\n            data_coordinator_handle=data_coordinator_handle,\n            metric_worker_handle=metric_worker_handle,\n            device_name=siirl_args.trainer.device,\n        )\n\n        # 7. Initialize and start DAGWorkers\n        logger.info(\"Initializing and starting DAG workers...\")\n        trainer.init_workers()\n        trainer.start_workers()\n\n        setup_duration = time.time() - start_time\n        logger.info(f\"Workflow setup and worker launch complete. 
Time cost: {setup_duration:.2f}s\")\n\n\ndef main() -> None:\n    \"\"\"\n    Main entry point for launching the PPO DAG training job.\n\n    This function initializes Ray, parses configurations using Hydra, and\n    starts the MainRunner actor to orchestrate the distributed training workflow.\n    \"\"\"\n    # Import logger locally to avoid Ray serialization issues\n    from loguru import logger\n\n    start_time = time.time()\n\n    # Initialize Ray cluster if not already running\n    if not ray.is_initialized():\n        logger.info(\"Initializing local Ray cluster...\")\n        ray.init(runtime_env={\"env_vars\": RAY_RUNTIME_ENV_VARS}, num_cpus=None)\n    logger.success(f\"Ray is initialized. Time cost: {(time.time() - start_time) * 1000:.2f} ms\")\n\n    # Parse the complete configuration into a structured object\n    siirl_args = parse_config()\n    log_dict_formatted(siirl_args.to_dict(), \"SiiRLArguments\")\n\n    _maybe_prepare_embodied_manifest(siirl_args)\n\n    # Launch the main orchestration actor and wait for it to complete.\n    logger.info(\"Starting MainRunner actor to orchestrate the job.\")\n    runner = MainRunner.remote()\n    # This is a blocking call that waits for the remote `run` method to finish.\n    ray.get(runner.run.remote(siirl_args))\n\n    logger.success(\"MainRunner has completed its execution. Shutting down.\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
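  {
    "path": "examples/custom_pipeline_example.py",
    "content": "\"\"\"\nIllustrative sketch (hypothetical example file) of the dag.custom_pipeline_fn\ncontract consumed by load_pipeline in siirl/main_dag.py. Place this module\nwherever it is importable and reference it in the \"module.path:function_name\"\nformat, e.g. dag.custom_pipeline_fn=\"examples.custom_pipeline_example:build_pipeline\".\n\"\"\"\n\nfrom siirl.execution.dag import TaskGraph\nfrom siirl.execution.dag.builtin_pipelines import grpo_pipeline\n\n\ndef build_pipeline() -> TaskGraph:\n    \"\"\"Return the TaskGraph to run. This sketch simply reuses the built-in\n    GRPO pipeline; a real custom pipeline would build and validate its own\n    graph before returning it.\"\"\"\n    graph = grpo_pipeline()\n    # load_pipeline() rejects anything that is not a TaskGraph, so keep the\n    # return type intact when customizing.\n    assert isinstance(graph, TaskGraph)\n    return graph\n"
  },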
  {
    "path": "siirl/models/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "siirl/models/embodied/openvla/__init__.py",
    "content": ""
  },
  {
    "path": "siirl/models/embodied/openvla/configuration_prismatic.py",
    "content": "\"\"\"\nconfiguration_prismatic.py\n\nHuggingFace-style configuration definition for Prismatic VLMs, inheriting from `transformers.PretrainedConfig`.\nDefault configuration specifies `siglip-224px+7b`.\n\"\"\"\n\nfrom typing import Any, Dict, List, Optional\n\nfrom transformers import PretrainedConfig\nfrom transformers.models.auto import CONFIG_MAPPING\n\n# === Utilities for Mapping Prismatic names to HF names ===\n# fmt: off\nVISION_BACKBONE_TO_RESOLUTION: Dict[str, List[int]] = {\n    \"clip-vit-l\": [224], \"siglip-vit-so400m\": [224], \"dinov2-vit-l\": [224], \"in1k-vit-l\": [224],\n\n    \"clip-vit-l-336px\": [336],\n    \"siglip-vit-so400m-384px\": [384],\n\n    \"dinoclip-vit-l-336px\": [336, 336],\n    \"dinosiglip-vit-so-224px\": [224, 224],\n    \"dinosiglip-vit-so-384px\": [384, 384],\n}\nVISION_BACKBONE_TO_TIMM_ID: Dict[str, List[str]] = {\n    \"clip-vit-l\": [\"vit_large_patch14_clip_224.openai\"],\n    \"clip-vit-l-336px\": [\"vit_large_patch14_clip_336.openai\"],\n\n    \"dinov2-vit-l\": [\"vit_large_patch14_reg4_dinov2.lvd142m\"],\n    \"in1k-vit-l\": [\"vit_large_patch16_224.augreg_in21k_ft_in1k\"],\n\n    \"siglip-vit-so400m\": [\"vit_so400m_patch14_siglip_224\"],\n    \"siglip-vit-so400m-384px\": [\"vit_so400m_patch14_siglip_384\"],\n\n    \"dinoclip-vit-l-336px\": [\"vit_large_patch14_reg4_dinov2.lvd142m\", \"vit_large_patch14_clip_336.openai\"],\n    \"dinosiglip-vit-so-224px\": [\"vit_large_patch14_reg4_dinov2.lvd142m\", \"vit_so400m_patch14_siglip_224\"],\n    \"dinosiglip-vit-so-384px\": [\"vit_large_patch14_reg4_dinov2.lvd142m\", \"vit_so400m_patch14_siglip_384\"],\n}\nTIMM_OVERRIDE_ACT_LAYER: Dict[str, List[Optional[str]]] = {\n    \"clip-vit-l\": [\"quick_gelu\"], \"clip-vit-l-336px\": [\"quick_gelu\"],\n    \"dinov2-vit-l\": [None], \"in1k-vit-l\": [None],\n    \"siglip-vit-so400m\": [None], \"siglip-vit-so400m-384px\": [None],\n    \"dinoclip-vit-l-336px\": [None, \"quick_gelu\"],\n    \"dinosiglip-vit-so-224px\": [None, None], \"dinosiglip-vit-so-384px\": [None, None]\n}\n\nLLM_BACKBONE_TO_HF_PATH = {\n    \"llama2-7b-pure\": \"meta-llama/Llama-2-7b-hf\", \"llama2-13b-pure\": \"meta-llama/Llama-2-13b-hf\",\n    \"llama2-7b-chat\": \"meta-llama/Llama-2-7b-chat-hf\", \"llama2-13b-chat\": \"meta-llama/Llama-2-13b-chat-hf\",\n\n    \"vicuna-v15-7b\": \"lmsys/vicuna-7b-v1.5\", \"vicuna-v15-13b\": \"lmsys/vicuna-13b-v1.5\",\n\n    \"mistral-v0.1-7b-pure\": \"mistralai/Mistral-7B-v0.1\",\n    \"mistral-v0.1-7b-instruct\": \"mistralai/Mistral-7B-Instruct-v0.1\",\n\n    \"phi-2-3b\": \"microsoft/phi-2\",\n}\nLLM_BACKBONE_TO_HF_METACLASS = {\n    \"llama2-7b-pure\": \"llama\", \"llama2-13b-pure\": \"llama\", \"llama2-7b-chat\": \"llama\", \"llama2-13b-chat\": \"llama\",\n    \"vicuna-v15-7b\": \"llama\", \"vicuna-v15-13b\": \"llama\",\n\n    \"mistral-v0.1-7b-pure\": \"mistral\", \"mistral-v0.1-7b-instruct\": \"mistral\",\n\n    \"phi-2-3b\": \"phi\",\n}\n\nVALID_VISION_BACKBONES = set(VISION_BACKBONE_TO_RESOLUTION.keys())\nVALID_LLM_BACKBONES = set(LLM_BACKBONE_TO_HF_PATH)\n# fmt: on\n\n\nclass PrismaticConfig(PretrainedConfig):\n    model_type: str = \"prismatic\"\n    is_composition: bool = False\n\n    def __init__(\n        self,\n        vision_backbone_id: str = \"siglip-vit-so400m\",\n        llm_backbone_id: str = \"vicuna-v15-7b\",\n        arch_specifier: str = \"no-align+gelu-mlp\",\n        use_fused_vision_backbone: Optional[bool] = None,\n        image_resize_strategy: str = \"letterbox\",\n        text_config: Optional[Dict[str, Any]] = 
None,\n        llm_max_length: int = 2048,\n        pad_token_id: int = 32000,\n        pad_to_multiple_of: int = 64,\n        output_projector_states: bool = False,\n        **kwargs: str,\n    ) -> None:\n        if vision_backbone_id not in VALID_VISION_BACKBONES:\n            raise ValueError(f\"Vision backbone `{vision_backbone_id}` not in {VALID_VISION_BACKBONES = }\")\n\n        if llm_backbone_id not in VALID_LLM_BACKBONES:\n            raise ValueError(f\"LLM backbone `{llm_backbone_id}` not in {VALID_LLM_BACKBONES = }\")\n\n        # Set Prismatic Configuration Fields\n        self.vision_backbone_id = vision_backbone_id\n        self.llm_backbone_id = llm_backbone_id\n        self.arch_specifier = arch_specifier\n        self.output_projector_states = output_projector_states\n\n        # [Contract] All vision backbone parameters are lists =>> supports fused backbones with different preprocessing\n        self.use_fused_vision_backbone = (\n            use_fused_vision_backbone\n            if use_fused_vision_backbone is not None\n            else any(self.vision_backbone_id.startswith(v) for v in [\"dinoclip\", \"dinosiglip\"])\n        )\n\n        self.timm_model_ids = VISION_BACKBONE_TO_TIMM_ID[self.vision_backbone_id]\n        self.timm_override_act_layers = TIMM_OVERRIDE_ACT_LAYER[self.vision_backbone_id]\n        self.image_sizes = VISION_BACKBONE_TO_RESOLUTION[self.vision_backbone_id]\n        self.image_resize_strategy = image_resize_strategy\n\n        self.hf_llm_id = LLM_BACKBONE_TO_HF_PATH[self.llm_backbone_id]\n        self.llm_max_length = llm_max_length\n        self.pad_token_id, self.pad_to_multiple_of = pad_token_id, pad_to_multiple_of\n\n        # [IMPORTANT] HF Utilities actually look for a `text_config` field... we need to use that specific naming!\n        self.text_config = (\n            CONFIG_MAPPING[LLM_BACKBONE_TO_HF_METACLASS[self.llm_backbone_id]](**text_config)\n            if text_config is not None\n            else CONFIG_MAPPING[LLM_BACKBONE_TO_HF_METACLASS[self.llm_backbone_id]]()\n        )\n\n        # Dispatch **kwargs to super() =>> note that `pad_token_id` collides, so we pass it in here as well...\n        super().__init__(pad_token_id=pad_token_id, **kwargs)\n\n\nclass OpenVLAConfig(PrismaticConfig):\n    model_type: str = \"openvla\"\n\n    def __init__(\n        self,\n        norm_stats: Optional[Dict[str, Dict[str, Dict[str, Dict[str, List[float]]]]]] = None,\n        n_action_bins: int = 256,\n        **kwargs: str,\n    ) -> None:\n        self.norm_stats, self.n_action_bins = norm_stats, n_action_bins\n\n        super().__init__(**kwargs)\n"
  },
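  {
    "path": "examples/openvla_config_example.py",
    "content": "\"\"\"\nIllustrative sketch (hypothetical example file): shows how the backbone-ID\nlookup tables in configuration_prismatic.py populate an OpenVLAConfig.\nRequires the `transformers` package; no weights are downloaded, and the\nbackbone IDs below are just entries picked from the mapping tables.\n\"\"\"\n\nfrom siirl.models.embodied.openvla.configuration_prismatic import OpenVLAConfig\n\nif __name__ == \"__main__\":\n    # A fused (dino + siglip) backbone flips `use_fused_vision_backbone`\n    # automatically and yields two timm model IDs / image sizes.\n    cfg = OpenVLAConfig(\n        vision_backbone_id=\"dinosiglip-vit-so-224px\",\n        llm_backbone_id=\"llama2-7b-pure\",\n        n_action_bins=256,\n    )\n    print(cfg.use_fused_vision_backbone)  # True, inferred from the \"dinosiglip\" prefix\n    print(cfg.timm_model_ids)  # one timm ID per fused backbone\n    print(cfg.image_sizes)  # [224, 224]\n    print(cfg.hf_llm_id)  # meta-llama/Llama-2-7b-hf\n    print(cfg.text_config.model_type)  # llama\n"
  },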
  {
    "path": "siirl/models/embodied/openvla/modeling_prismatic.py",
    "content": "\"\"\"\nmodeling_prismatic.py\n\nCore HuggingFace-style PrismaticPreTrainedModel and PrismaticForConditionalGeneration class definitions, inheriting\nfrom the default `transformers.PretrainedModel`. Meant to be standalone and self-contained, but exactly replicate the\nlogic in `prismatic.models.vlms.prismatic.py`.\n\nNote =>> for the time being, not adding the custom HF \"docstring\" formatting.\n\nReferences [LLaVa, IDEFICS-2]:\n    => https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py\n    => https://github.com/huggingface/transformers/blob/main/src/transformers/models/idefics2/modeling_idefics2.py\n\"\"\"\n\nimport logging\nfrom dataclasses import dataclass\nfrom functools import partial\nfrom typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport timm\nimport tokenizers\nimport torch\nimport torch.nn as nn\nimport transformers\nfrom timm.models.vision_transformer import LayerScale\nfrom transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel\nfrom transformers.modeling_outputs import ModelOutput\n\nfrom .configuration_prismatic import OpenVLAConfig, PrismaticConfig\n\n# Get Logger\nlogger = logging.getLogger(__name__)\n\n\n# === PyTorch/HuggingFace Default IGNORE_INDEX (for CrossEntropyLoss labels)\nIGNORE_INDEX = -100\n\n\n# === Utility Functions for Monkey-Patching ===\ndef unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:\n    def wrapper(*args: Any, **kwargs: Any) -> Any:\n        result = fn(*args, **kwargs)\n        return result[0] if isinstance(result, tuple) else result\n\n    return wrapper\n\n\n# HF Transformers overwrites parameters with names containing `gamma`; we're going to patch VisionBackbone.LayerScale.\n#   =>> TIMM :: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L109\n#   =>> Transformers :: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3960\ndef _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor:\n    return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor\n\n\ndef ls_apply_patch(ls_module: LayerScale):\n    ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone())\n    ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale)\n    del ls_module.gamma\n\n\n# === Prismatic Vision Backbone (nn.Module) Definitions (w/ Fused Backbone Support) ===\nclass PrismaticVisionBackbone(nn.Module):\n    def __init__(\n        self,\n        use_fused_vision_backbone: bool,\n        image_sizes: List[int],\n        timm_model_ids: List[str],\n        timm_override_act_layers: List[Optional[str]],\n    ) -> None:\n        super().__init__()\n        self.use_fused_vision_backbone = use_fused_vision_backbone\n\n        # [Contract] Validate number of (fused) vision backbones, create \"alpha\" featurizer and Instantiate\n        #   =>> Note :: Monkey-Patch the `forward()` function of the backbone to ensure FSDP-compatibility\n        #               Hardcodes `get_intermediate_layers` to return the **SECOND-TO-LAST** layer patches!\n        assert len(timm_model_ids) <= 2, \"Prismatic models only support up to 2 (fused) vision backbones!\"\n        self.featurizer = timm.create_model(\n            timm_model_ids[0],\n            pretrained=False,\n            num_classes=0,\n            img_size=image_sizes[0],\n            act_layer=timm_override_act_layers[0],\n        )\n        
self.featurizer.forward = unpack_tuple(\n            partial(self.featurizer.get_intermediate_layers, n={len(self.featurizer.blocks) - 2})\n        )\n        self.embed_dim = self.featurizer.embed_dim\n\n        # If `use_fused_vision_backbone` =>> create \"beta\" featurizer\n        if self.use_fused_vision_backbone:\n            self.fused_featurizer = timm.create_model(\n                timm_model_ids[1],\n                pretrained=False,\n                num_classes=0,\n                img_size=image_sizes[1],\n                act_layer=timm_override_act_layers[1],\n            )\n            self.fused_featurizer.forward = unpack_tuple(\n                partial(self.fused_featurizer.get_intermediate_layers, n={len(self.fused_featurizer.blocks) - 2})\n            )\n            self.embed_dim += self.fused_featurizer.embed_dim\n\n        # Patch `vision_backbone.featurizer` and `vision_backbone.fused_featurizer` with HF-Compatible LayerScale\n        for module in self.featurizer.modules():\n            if isinstance(module, LayerScale):\n                ls_apply_patch(module)\n\n        if self.use_fused_vision_backbone:\n            for module in self.fused_featurizer.modules():\n                if isinstance(module, LayerScale):\n                    ls_apply_patch(module)\n\n    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:\n        \"\"\"Run image (`pixel_values`) through featurizer; if channel-stacked, then dispatch and sequence stack.\"\"\"\n        if not self.use_fused_vision_backbone:\n            return self.featurizer(pixel_values)\n\n        # Split `pixel_values :: [bsz, 2 * 3, resolution, resolution]` =>> featurize =>> channel stack\n        img, img_fused = torch.split(pixel_values, [3, 3], dim=1)\n        patches, patches_fused = self.featurizer(img), self.fused_featurizer(img_fused)\n\n        return torch.cat([patches, patches_fused], dim=2)\n\n\n# === Prismatic Projector (nn.Module) Definitions ===\nclass PrismaticProjector(nn.Module):\n    def __init__(self, use_fused_vision_backbone: bool, vision_dim: int, llm_dim: int) -> None:\n        super().__init__()\n        self.use_fused_vision_backbone = use_fused_vision_backbone\n        self.vision_dim, self.llm_dim = vision_dim, llm_dim\n\n        # Switch on `use_fused_vision_backbone` =>> use slightly different MLPs and projection factors!\n        if not self.use_fused_vision_backbone:\n            self.fc1 = nn.Linear(self.vision_dim, self.llm_dim, bias=True)\n            self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)\n            self.act_fn1 = nn.GELU()\n        else:\n            initial_projection_dim = 4 * vision_dim\n            self.fc1 = nn.Linear(self.vision_dim, initial_projection_dim, bias=True)\n            self.fc2 = nn.Linear(initial_projection_dim, self.llm_dim, bias=True)\n            self.fc3 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)\n            self.act_fn1 = nn.GELU()\n            self.act_fn2 = nn.GELU()\n\n    def forward(self, img_patches: torch.Tensor) -> torch.Tensor:\n        if not self.use_fused_vision_backbone:\n            projected_features = self.fc1(img_patches)\n            projected_features = self.act_fn1(projected_features)\n            projected_features = self.fc2(projected_features)\n        else:\n            projected_features = self.fc1(img_patches)\n            projected_features = self.act_fn1(projected_features)\n            projected_features = self.fc2(projected_features)\n            projected_features = 
self.act_fn2(projected_features)\n            projected_features = self.fc3(projected_features)\n\n        return projected_features\n\n\n# === Main HF Class Definitions ===\n@dataclass\nclass PrismaticCausalLMOutputWithPast(ModelOutput):\n    \"\"\"Base class for Prismatic causal (visually-conditioned) language model outputs; also exposes visual features.\"\"\"\n\n    loss: Optional[torch.FloatTensor] = None\n    logits: torch.FloatTensor = None\n    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None\n    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None\n    attentions: Optional[Tuple[torch.FloatTensor]] = None\n\n    # Additions for VLMs\n    projector_features: Optional[torch.FloatTensor] = None\n\n\nclass PrismaticPreTrainedModel(PreTrainedModel):\n    config_class: PretrainedConfig = PrismaticConfig\n    base_model_prefix: str = \"model\"\n    supports_gradient_checkpointing: bool = True\n\n    _no_split_modules: ClassVar[List[str]] = [\"PrismaticProjector\"]\n    _skip_keys_device_placement: str = \"past_key_values\"\n    _supports_flash_attn_2: bool = True\n\n    def _init_weights(self, module: nn.Module) -> None:\n        # Important :: this HF ported version is *not* meant for training from scratch; only inference and fine-tuning!\n        #   => As such, this init_weights code is not correct; if training VLMs from scratch, use the main codebase at\n        #      https://github.com/TRI-ML/prismatic-vlms\n        std = (\n            self.config.initializer_range\n            if hasattr(self.config, \"initializer_range\")\n            else self.config.text_config.initializer_range\n        )\n\n        if hasattr(module, \"class_embedding\"):\n            module.class_embedding.data.normal_(mean=0.0, std=std)\n\n        if isinstance(module, (nn.Linear, nn.Conv2d)):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n\n    @property\n    def _supports_sdpa(self) -> bool:\n        \"\"\"Check LLM supports SDPA Attention\"\"\"\n        return self.language_model._supports_sdpa\n\n\nclass PrismaticForConditionalGeneration(PrismaticPreTrainedModel):\n    def __init__(self, config: PrismaticConfig) -> None:\n        super().__init__(config)\n\n        # [Validation] Lightweight Validate on `config` Fields + Dependency Versions\n        if config.use_fused_vision_backbone is None:\n            raise ValueError(\"Missing config field `use_fused_vision_backbone`\")\n\n        if timm.__version__ not in {\"0.9.10\", \"0.9.11\", \"0.9.12\", \"0.9.16\"}:\n            raise NotImplementedError(\n                \"TIMM Version must be >= 0.9.10 and < 1.0.0 (breaking); please raise a GitHub Issue \"\n                \"if you urgently need support for latest TIMM versions.\"\n            )\n\n        if (transformers.__version__ != \"4.40.1\") or (tokenizers.__version__ != \"0.19.1\"):\n            logger.warning(\n                f\"Expected `transformers==4.40.1` and `tokenizers==0.19.1` but got \"\n                f\"`transformers=={transformers.__version__}` and `tokenizers=={tokenizers.__version__}`; \"\n                f\"there might be inference-time regressions due to dependency changes. 
If in doubt, please\"\n                f\"use the above versions.\"\n            )\n\n        # Instantiate PrismaticVisionBackbone (w/ Potential Fused Backbone)\n        self.vision_backbone = PrismaticVisionBackbone(\n            config.use_fused_vision_backbone, config.image_sizes, config.timm_model_ids, config.timm_override_act_layers\n        )\n\n        # Create Multimodal Projector\n        self.projector = PrismaticProjector(\n            config.use_fused_vision_backbone,\n            vision_dim=self.vision_backbone.embed_dim,\n            llm_dim=config.text_config.hidden_size,\n        )\n\n        # Instantiate LLM Backbone\n        self.language_model = AutoModelForCausalLM.from_config(\n            config.text_config, attn_implementation=config._attn_implementation\n        )\n        self.vocab_size = config.text_config.vocab_size\n        self.pad_token_id = config.pad_token_id\n\n        # HF Boilerplate =>> initializes weights via `_init_weights()` and sets gradient checkpointing\n        self.post_init()\n\n    # === `PreTrainedModel` Boilerplate ===\n    def get_input_embeddings(self) -> nn.Module:\n        return self.language_model.get_input_embeddings()\n\n    def set_input_embeddings(self, value: nn.Module) -> None:\n        self.language_model.set_input_embeddings(value)\n\n    def get_output_embeddings(self) -> nn.Module:\n        return self.language_model.get_output_embeddings()\n\n    def set_output_embeddings(self, new_embeddings: nn.Module) -> None:\n        self.language_model.set_output_embeddings(new_embeddings)\n\n    def get_decoder(self) -> nn.Module:\n        return self.language_model.get_decoder()\n\n    def set_decoder(self, decoder: nn.Module) -> None:\n        self.language_model.set_decoder(decoder)\n\n    def tie_weights(self) -> None:\n        self.language_model.tie_weights()  # Note: `Llama-2` and `Mistral` don't tie weights (no-op)\n\n    def resize_token_embeddings(\n        self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None\n    ) -> nn.Embedding:\n        updated_embeddings = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)\n\n        # Update config/instance variables\n        self.config.text_config.vocab_size = updated_embeddings.num_embeddings\n        self.vocab_size = updated_embeddings.num_embeddings\n\n        return updated_embeddings\n\n    # === Core Prismatic VLM `forward()` Logic ===\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        pixel_values: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        output_projector_features: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, PrismaticCausalLMOutputWithPast]:\n        \"\"\"Run a forward pass through the VLM, returning a PrismaticCausalLMOutputWithPast instance.\"\"\"\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        
output_projector_features = output_projector_features if output_projector_features is not None else False\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        # Respect `use_cache` only if not training (even if `gradient_checkpointing` is off)\n        use_cache = use_cache and not self.training\n\n        # Instantiate Placeholder for Projector Features\n        projected_patch_embeddings = None\n\n        # Note :: We only support forward passes with the following cases:\n        #   => Cached Generation :: (input_ids.shape[1] == 1) and (past_key_values is not None)\n        #   => Unimodal Forward :: (pixel_values is None)\n        #   => Multimodal Forward :: (pixel_values is not None) and (input_ids/embeds.shape[0] == pixel_values.shape[0])\n\n        # === Handle Generation with Cache (`input_ids.shape[1] == 1`) =>> requires `past_keys_values` ===\n        if input_ids.shape[1] == 1:\n            #assert input_ids.shape[0] == 1, \"Generation is only currently supported for batch size of 1!\"\n            assert past_key_values is not None, \"You must provide `past_key_values` during cached generation!\"\n            assert labels is None, \"Unexpected key `labels` provided during cached generation!\"\n\n            language_model_output = self.language_model(\n                input_ids=input_ids,\n                attention_mask=None,\n                position_ids=None,\n                past_key_values=past_key_values,\n                inputs_embeds=None,\n                labels=None,\n                use_cache=use_cache,\n                output_attentions=output_attentions,\n                output_hidden_states=output_hidden_states,\n                return_dict=return_dict,\n            )\n\n        # === Handle Unimodal Forward ===\n        elif pixel_values is None:\n            assert (input_ids is not None) and (inputs_embeds is None), \"Missing `input_ids` in language-only forward!\"\n            assert past_key_values is None, \"Unexpected key `past_key_values` provided during language-only forward!\"\n\n            language_model_output = self.language_model(\n                input_ids=input_ids,\n                attention_mask=attention_mask,\n                position_ids=None,\n                past_key_values=None,\n                inputs_embeds=None,\n                labels=labels,\n                use_cache=use_cache,\n                output_attentions=output_attentions,\n                output_hidden_states=output_hidden_states,\n                return_dict=return_dict,\n            )\n\n        # === Handle Multimodal Forward ===\n        elif (input_ids.shape[0] == pixel_values.shape[0]) or (inputs_embeds.shape[0] == pixel_values.shape[0]):\n            assert past_key_values is None, \"Unexpected key `past_key_values` provided during language-only forward!\"\n\n            # Visual Feature Extraction\n            patch_features = self.vision_backbone(pixel_values)\n\n            # Projection Logic =>> Update Attention Mask\n            projected_patch_embeddings = self.projector(patch_features)\n            projected_patch_attention_mask = None\n            if attention_mask is not None:\n                projected_patch_attention_mask = torch.full(\n                    (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),\n                    fill_value=True,\n                    dtype=attention_mask.dtype,\n                    device=attention_mask.device,\n                )\n\n            # 
Get Input Embeddings (from Language Model Embeddings)\n            input_embeddings = self.get_input_embeddings()(input_ids)\n\n            # Build Multimodal Embeddings & Attention Mask =>> Prismatic defaults to inserting after <BOS> token (1:)\n            multimodal_embeddings = torch.cat(\n                [input_embeddings[:, :1, :], projected_patch_embeddings, input_embeddings[:, 1:, :]], dim=1\n            )\n            multimodal_attention_mask = None\n            if attention_mask is not None:\n                multimodal_attention_mask = torch.cat(\n                    [attention_mask[:, :1], projected_patch_attention_mask, attention_mask[:, 1:]], dim=1\n                )\n\n            # Build Labels (if specified) =>> Ignore Labels for Patch Embeddings\n            multimodal_labels = None\n            if labels is not None:\n                projected_patch_labels = torch.full(\n                    (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),\n                    fill_value=IGNORE_INDEX,\n                    dtype=labels.dtype,\n                    device=labels.device,\n                )\n                multimodal_labels = torch.cat([labels[:, :1], projected_patch_labels, labels[:, 1:]], dim=1)\n\n            # Dispatch to Language Model\n            language_model_output = self.language_model(\n                input_ids=None,\n                attention_mask=multimodal_attention_mask,\n                position_ids=None,\n                past_key_values=None,\n                inputs_embeds=multimodal_embeddings,\n                labels=multimodal_labels,\n                use_cache=use_cache,\n                output_attentions=output_attentions,\n                output_hidden_states=output_hidden_states,\n                return_dict=return_dict,\n            )\n\n        # === Otherwise =>> Assume Invalid! 
===\n        elif (input_ids.shape[0] != pixel_values.shape[0]) or (inputs_embeds.shape[0] != pixel_values.shape[0]):\n            raise ValueError(\"Non-homogenous batch of (text, image) input -- forward() does not support mixed batches!\")\n\n        else:\n            raise ValueError(\n                \"Invalid PrismaticForConditionalGeneration `forward()` call with provided arguments:\\n\"\n                f\"=> `input_ids` = {input_ids is not None}\\n\"\n                f\"=> `attention_mask` = {attention_mask is not None}\\n\"\n                f\"=> `pixel_values` = {pixel_values is not None}\\n\"\n                f\"=> `labels` = {labels is not None}\\n\"\n                f\"=> `input_embeds` = {inputs_embeds is not None}\\n\"\n                f\"=> `past_key_values` = {past_key_values is not None}\\n\"\n                f\"=> `use_cache` = {use_cache}\"\n            )\n\n        # Unpack `language_model_output` and return PrismaticCausalLMOutputWithPast (or tuple if not `return_dict`)\n        if not return_dict:\n            if output_projector_features and (projected_patch_embeddings is not None):\n                return *language_model_output, projected_patch_embeddings\n\n            return language_model_output\n\n        return PrismaticCausalLMOutputWithPast(\n            loss=language_model_output.loss,\n            logits=language_model_output.logits,\n            past_key_values=language_model_output.past_key_values,\n            hidden_states=language_model_output.hidden_states,\n            attentions=language_model_output.attentions,\n            projector_features=projected_patch_embeddings,\n        )\n\n    # === GenerationMixin Methods ===\n    def prepare_inputs_for_generation(\n        self,\n        input_ids: Optional[torch.Tensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        pixel_values: Optional[torch.FloatTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        **kwargs: str,\n    ) -> Dict[str, torch.Tensor]:\n        \"\"\"Borrowed from `LlamaForCausalLM` and simplified for batch size = 1; mirrors original PrismaticVLM logic.\"\"\"\n        # if ((input_ids is not None) and (input_ids.shape[0] > 1)) or (\n        #     (inputs_embeds is not None) and (inputs_embeds.shape[0] > 1)\n        # ):\n        #     raise ValueError(\"Generation with batch size > 1 is not currently supported!\")\n\n        # Handle `past_key_values` (cache) =>> assume `input_ids` just has unprocessed tokens\n        if past_key_values is not None:\n            input_ids = input_ids[:, -1:]\n\n        # If `input_embeds` are passed, we only want to use them in the 1st generation step\n        if inputs_embeds is not None and past_key_values is None:\n            model_inputs = {\"input_embeds\": inputs_embeds}\n        else:\n            model_inputs = {\"input_ids\": input_ids}\n\n        # Make sure `pixel_values` are preserved in `model_inputs`\n        model_inputs.update(\n            {\n                \"attention_mask\": attention_mask,\n                \"pixel_values\": pixel_values,\n                \"past_key_values\": past_key_values,\n                \"use_cache\": kwargs.get(\"use_cache\"),\n            }\n        )\n\n        return model_inputs\n\n    # Defer to Language Model (all handle this differently, with different return types)\n    def _reorder_cache(self, *args, **kwargs) -> Any:\n        return self.language_model._reorder_cache(*args, 
**kwargs)\n\n\nclass OpenVLAForActionPrediction(PrismaticForConditionalGeneration):\n    config_class: PretrainedConfig = OpenVLAConfig\n\n    def __init__(self, config: OpenVLAConfig) -> None:\n        super().__init__(config)\n        self.norm_stats = config.norm_stats\n\n        # Compute action bins\n        self.bins = np.linspace(-1, 1, config.n_action_bins)\n        self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0\n\n        # Compute vocab size for de-tokenization -- revert added \"multiple of\"\n        self.vocab_size = self.config.text_config.vocab_size - self.config.pad_to_multiple_of\n\n    def predict_action(\n        self, input_ids: Optional[torch.LongTensor] = None, unnorm_key: Optional[str] = None, **kwargs: str\n    ) -> np.ndarray:\n        \"\"\"Thin wrapper around .generate() that decodes predicted actions and unnormalizes them.\"\"\"\n        # If the special empty token ('') does not already appear after the colon (':') token in the prompt\n        # (after \"OUT:\" or \"ASSISTANT:\"), insert it to match the inputs seen at training time\n        if not torch.all(input_ids[:, -1] == 29871):\n            input_ids = torch.cat(\n                (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1\n            )\n\n        # Run VLA inference\n        generated_ids = self.generate(input_ids, max_new_tokens=self.get_action_dim(unnorm_key), **kwargs)\n\n        # Extract predicted action tokens and translate into (normalized) continuous actions\n        predicted_action_token_ids = generated_ids[0, -self.get_action_dim(unnorm_key) :].cpu().numpy()\n        discretized_actions = self.vocab_size - predicted_action_token_ids\n        discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)\n        normalized_actions = self.bin_centers[discretized_actions]\n\n        # Unnormalize actions\n        action_norm_stats = self.get_action_stats(unnorm_key)\n        mask = action_norm_stats.get(\"mask\", np.ones_like(action_norm_stats[\"q01\"], dtype=bool))\n        action_high, action_low = np.array(action_norm_stats[\"q99\"]), np.array(action_norm_stats[\"q01\"])\n        actions = np.where(\n            mask,\n            0.5 * (normalized_actions + 1) * (action_high - action_low) + action_low,\n            normalized_actions,\n        )\n\n        return actions\n\n    @staticmethod\n    def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str:\n        if unnorm_key is None:\n            assert len(norm_stats) == 1, (\n                f\"Your model was trained on more than one dataset, \"\n                f\"please pass a `unnorm_key` from the following options to choose the statistics \"\n                f\"used for un-normalizing actions: {norm_stats.keys()}\"\n            )\n            unnorm_key = next(iter(norm_stats.keys()))\n\n        assert unnorm_key in norm_stats, (\n            f\"The `unnorm_key` you chose is not in the set of available dataset statistics, \"\n            f\"please choose from: {norm_stats.keys()}\"\n        )\n        return unnorm_key\n\n    def get_action_dim(self, unnorm_key: Optional[str] = None) -> int:\n        \"\"\"Get the dimensionality of the policy's action space.\"\"\"\n        unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)\n        return len(self.norm_stats[unnorm_key][\"action\"][\"q01\"])\n\n    def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict[str, Any]:\n        
\"\"\"Get all the logged statistics for the given dataset.\"\"\"\n        unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)\n        return self.norm_stats[unnorm_key][\"action\"]\n"
  },
  {
    "path": "siirl/models/embodied/openvla/processing_prismatic.py",
    "content": "\"\"\"\nprocessing_prismatic.py\n\nHuggingFace-style preprocessor definitions for Prismatic VLMs, inheriting from `ProcessorMixin`. Default configuration\nspecifies `siglip-224px+7b`.\n\"\"\"\n\nfrom typing import Any, ClassVar, List, Optional, Tuple, Union\n\nimport timm.data\nimport torch\nimport torchvision.transforms.functional as TVF\nfrom PIL import Image\nfrom torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor\nfrom transformers import PreTrainedTokenizerBase\nfrom transformers.image_processing_utils import BatchFeature, ImageProcessingMixin\nfrom transformers.processing_utils import ProcessorMixin\nfrom transformers.tokenization_utils import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy\nfrom transformers.utils import TensorType\n\n\n# === Image Processing ===\ndef letterbox_pad_transform(image: Image.Image, padding_fill_value: Tuple[int, int, int]) -> Image.Image:\n    \"\"\"Given a PIL.Image, pad to square by adding a symmetric border around the height/width.\"\"\"\n    (w, h), max_wh = image.size, max(image.size)\n    horizontal_pad, vertical_pad = int((max_wh - w) / 2), int((max_wh - h) / 2)\n    padding = (horizontal_pad, vertical_pad, horizontal_pad, vertical_pad)\n\n    return TVF.pad(image, padding, fill=padding_fill_value, padding_mode=\"constant\")\n\n\nclass PrismaticImageProcessor(ImageProcessingMixin):\n    model_input_names: ClassVar[List[str]] = [\"pixel_values\"]\n\n    def __init__(\n        self,\n        use_fused_vision_backbone: bool = False,\n        image_resize_strategy: str = \"letterbox\",\n        input_sizes: Optional[List[Tuple[int, int, int]]] = None,\n        interpolations: Optional[List[str]] = None,\n        means: Optional[List[Tuple[float, float, float]]] = None,\n        stds: Optional[List[Tuple[float, float, float]]] = None,\n        **kwargs: str,\n    ) -> None:\n        \"\"\"\n        Initialize a PrismaticImageProcessor as a wrapper around a torchvision transform; this transform will be\n        created by TIMM, and edited to follow our custom `image_resize_strategy` logic.\n\n        @param use_fused_vision_backbone: Boolean indicating single or fused (dual) vision backbone\n        @param image_resize_strategy: Prismatic image resize strategy in < resize-naive | resize-crop | letterbox >\n        @param input_size: [TIMM :: `data_cfg`] Input image size as tuple (channels, width, height)\n        @param interpolation: [TIMM :: `data_cfg`] Interpolation as string (default: \"bicubic\")\n        @param mean: [TIMM :: `data_cfg`] Normalization mean as float tuple (or two-tuple if `fused_backbone`)\n        @param std: [TIMM :: `data_cfg`] Normalization std as float tuple (or two-tuple if `fused_backbone`)\n        \"\"\"\n        self.use_fused_vision_backbone = use_fused_vision_backbone\n        self.image_resize_strategy = image_resize_strategy\n\n        # Handle `None` default values\n        input_sizes = [(3, 224, 224)] if input_sizes is None else input_sizes\n        means = [(0.5, 0.5, 0.5)] if means is None else means\n        stds = [(0.5, 0.5, 0.5)] if stds is None else stds\n\n        # TIMM `data_cfg` Parameters\n        self.input_sizes, self.interpolations, self.means, self.stds = input_sizes, interpolations, means, stds\n\n        # Grab torchvision transforms via TIMM =>> need to parse for specific \"functional\" transform values!\n        self.tvf_resize_params, self.tvf_crop_params, self.tvf_normalize_params = [], [], []\n        self.tvf_do_letterbox, 
self.tvf_letterbox_fill = False, None\n\n        for idx in range(len(input_sizes)):\n            transform = timm.data.create_transform(\n                input_size=self.input_sizes[idx],\n                interpolation=self.interpolations[idx],\n                mean=self.means[idx],\n                std=self.stds[idx],\n                crop_pct=1.0,  # Set to 1.0 to ignore cropping (initial Resize sets `input_size`)\n                crop_mode=\"center\",  # Default crop mode -- no-op when `crop_pct == 1.0`\n                is_training=False,  # No image augmentations when loading the transform!\n            )\n\n            # [Validation] Ensure appropriate transform structure, expected sizes\n            if not (\n                isinstance(transform, Compose)\n                and (len(transform.transforms) == 4)\n                and isinstance(transform.transforms[0], Resize)\n                and isinstance(transform.transforms[1], CenterCrop)\n                and isinstance(transform.transforms[2], ToTensor)\n                and isinstance(transform.transforms[3], Normalize)\n                and (transform.transforms[0].size == self.input_sizes[idx][-1])\n                and (transform.transforms[1].size == self.input_sizes[idx][-2:])\n            ):\n                raise ValueError(f\"Unexpected TIMM image transformation structure/sizes: `{transform}`\")\n\n            # HF Image Processors *must* be JSON-serializable; as such, cannot have torchvision. as an attribute.\n            #   => Instead, we're going to parse the transform and call \"torchvision.transforms.functional\" (`tvf`)\n            resize_t, crop_t, norm_t = transform.transforms[0], transform.transforms[1], transform.transforms[3]\n            self.tvf_resize_params.append(\n                {\n                    \"size\": resize_t.size,\n                    \"interpolation\": TVF.pil_modes_mapping[resize_t.interpolation],\n                    \"max_size\": None,\n                    \"antialias\": True,\n                }\n            )\n            self.tvf_crop_params.append({\"output_size\": crop_t.size})\n            self.tvf_normalize_params.append(\n                {\n                    \"mean\": norm_t.mean.float().numpy().tolist(),\n                    \"std\": norm_t.std.float().numpy().tolist(),\n                    \"inplace\": False,\n                }\n            )\n            self.tvf_do_letterbox, self.tvf_letterbox_fill = False, None\n\n            # Handle Prismatic `image_resize_strategy`\n            if self.image_resize_strategy == \"resize-naive\":\n                self.tvf_resize_params[idx][\"size\"] = (resize_t.size, resize_t.size)\n            elif self.image_resize_strategy == \"letterbox\":\n                self.tvf_do_letterbox, self.tvf_letterbox_fill = True, tuple([int(x * 255) for x in self.means[idx]])\n            elif self.image_resize_strategy == \"resize-crop\":\n                pass\n            else:\n                raise ValueError(f\"Image resize strategy `{self.image_resize_strategy}` is not supported!\")\n\n        # Dispatch **kwargs to super()\n        super().__init__(**kwargs)\n\n    def apply_transform(self, img: Image.Image) -> torch.Tensor:\n        \"\"\"Apply `functional` variant of TIMM's Transform = Compose([Resize -> CenterCrop -> ToTensor -> Normalize])\"\"\"\n        if self.tvf_do_letterbox:\n            img = letterbox_pad_transform(img, self.tvf_letterbox_fill)\n\n        # [Contract] Fused Backbones expect \"channel-stacked\" inputs; we'll unpack on the 
model side!\n        imgs_t = []\n        for idx in range(len(self.input_sizes)):\n            img_idx = TVF.resize(img, **self.tvf_resize_params[idx])\n            img_idx = TVF.center_crop(img_idx, **self.tvf_crop_params[idx])\n            img_idx_t = TVF.to_tensor(img_idx)\n            img_idx_t = TVF.normalize(img_idx_t, **self.tvf_normalize_params[idx])\n            imgs_t.append(img_idx_t)\n\n        # [Contract] `imgs_t` is a list of Tensors of shape [3, input_size, input_size]; stack along dim = 0\n        img_t = torch.vstack(imgs_t)\n\n        return img_t\n\n    def preprocess(\n        self,\n        images: Union[Image.Image, List[Image.Image]],\n        return_tensors: Optional[Union[str, TensorType]] = None,\n        **_: str,\n    ) -> BatchFeature:\n        \"\"\"\n        Preprocess an image (or batch of images); note that unlike the `transformers :: BaseImageProcessor` we\n        explicitly only handle PIL.Image.Image instances for simplicity.\n\n        @param images: A (batch of) PIL.Image.Image instance(s) to preprocess.\n        @param return_tensors: BatchFeature default Tensor format (e.g., \"pt\" for torch); if None, returns np.ndarray\n\n        @return: Instance of `transformers :: BatchFeature` with a single key \"pixel_values\"\n        \"\"\"\n        if not isinstance(images, list):\n            images = [images]\n\n        # Apply `self.img_transform` to each image (will return list of torch.Tensors); stack into \"batched\" Tensor\n        pixel_values = torch.stack([self.apply_transform(img.convert(\"RGB\")) for img in images])\n\n        # Return BatchFeature =>> note that for compatibility, constructor expects Dict[str, np.ndarray], so we convert\n        return BatchFeature(data={\"pixel_values\": pixel_values.float().numpy()}, tensor_type=return_tensors)\n\n    def __call__(self, images: Union[Image.Image, List[Image.Image]], **kwargs) -> BatchFeature:\n        return self.preprocess(images, **kwargs)\n\n\n# === PrismaticProcessor =>> Wraps both ImageProcessor and Tokenizer ===\n#   =>> https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/processing_llava.py\nclass PrismaticProcessor(ProcessorMixin):\n    attributes: ClassVar[List[str]] = [\"image_processor\", \"tokenizer\"]\n    image_processor_class: str = \"AutoImageProcessor\"\n    tokenizer_class: str = \"AutoTokenizer\"\n\n    def __init__(\n        self,\n        image_processor: Optional[ImageProcessingMixin] = None,\n        tokenizer: Optional[PreTrainedTokenizerBase] = None,\n    ) -> None:\n        super().__init__(image_processor, tokenizer)\n\n    def __call__(\n        self,\n        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],\n        images: Union[Image.Image, List[Image.Image]],\n        padding: Union[bool, str, PaddingStrategy] = False,\n        truncation: Optional[Union[bool, str, TruncationStrategy]] = None,\n        max_length: Optional[int] = None,\n        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,\n    ) -> BatchFeature:\n        \"\"\"\n        Preprocess a given (batch) of text/images for a Prismatic VLM; forwards text to the underlying LLM's tokenizer,\n        forwards images to PrismaticImageProcessor.\n\n        @param text: The (batch) of text to encode; must be a string or list of strings.\n        @param images: A (batch of) PIL.Image.Image instance(s) to preprocess.\n        @param padding: Sequence padding strategy (if multiple specified) in < True = \"longest\" | 
\"max_length\" | False >\n        @param truncation: Truncation strategy for the output sequences; requires `max_length` to be specified\n        @param max_length: Maximum length (in tokens) to truncate\n        @param return_tensors: Type of return tensors (usually \"pt\" or TensorType.PYTORCH)\n\n        @return: BatchFeature with keys for `input_ids`, `attention_mask` and `pixel_values`.\n        \"\"\"\n        pixel_values = self.image_processor(images, return_tensors=return_tensors)[\"pixel_values\"]\n        text_inputs = self.tokenizer(\n            text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length\n        )\n\n        # [Validate] Need same number of images and text inputs!\n        if pixel_values.shape[0] != text_inputs.input_ids.shape[0]:\n            raise ValueError(\"Batch is malformed; expected same number of images and text inputs!\")\n\n        return BatchFeature(data={**text_inputs, \"pixel_values\": pixel_values})\n\n    # === Tokenizer Dispatch Utilities =>> check `PreTrainedTokenizerBase` for documentation ===\n    def batch_decode(\n        self,\n        sequences: Union[List[int], List[List[int]], torch.Tensor, Any],  # `Any` = np.ndarray | tf.Tensor\n        skip_special_tokens: bool = False,\n        clean_up_tokenization_spaces: Optional[bool] = None,\n        **kwargs: str,\n    ) -> List[str]:\n        return self.tokenizer.batch_decode(\n            sequences=sequences,\n            skip_special_tokens=skip_special_tokens,\n            clean_up_tokenization_spaces=clean_up_tokenization_spaces,\n            **kwargs,\n        )\n\n    def decode(\n        self,\n        token_ids: Union[int, List[int], torch.Tensor, Any],  # `Any` = np.ndarray | tf.Tensor\n        skip_special_tokens: bool = False,\n        clean_up_tokenization_spaces: Optional[bool] = None,\n        **kwargs: str,\n    ) -> str:\n        return self.tokenizer.decode(\n            token_ids=token_ids,\n            skip_special_tokens=skip_special_tokens,\n            clean_up_tokenization_spaces=clean_up_tokenization_spaces,\n            **kwargs,\n        )\n\n    @property\n    def model_input_names(self) -> List[str]:\n        tokenizer_input_names = self.tokenizer.model_input_names\n        image_processor_input_names = self.image_processor.model_input_names\n\n        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))\n"
  },
  {
    "path": "siirl/models/embodied/openvla_oft/__init__.py",
    "content": ""
  },
  {
    "path": "siirl/models/embodied/openvla_oft/configuration_prismatic.py",
    "content": "\"\"\"\nconfiguration_prismatic.py\n\nHuggingFace-style configuration definition for Prismatic VLMs, inheriting from `transformers.PretrainedConfig`.\nDefault configuration specifies `siglip-224px+7b`.\n\"\"\"\n\nfrom typing import Any, Dict, List, Optional\n\nfrom transformers import PretrainedConfig\nfrom transformers.models.auto import CONFIG_MAPPING\n\n# === Utilities for Mapping Prismatic names to HF names ===\n# fmt: off\nVISION_BACKBONE_TO_RESOLUTION: Dict[str, List[int]] = {\n    \"clip-vit-l\": [224], \"siglip-vit-so400m\": [224], \"dinov2-vit-l\": [224], \"in1k-vit-l\": [224],\n\n    \"clip-vit-l-336px\": [336],\n    \"siglip-vit-so400m-384px\": [384],\n\n    \"dinoclip-vit-l-336px\": [336, 336],\n    \"dinosiglip-vit-so-224px\": [224, 224],\n    \"dinosiglip-vit-so-384px\": [384, 384],\n}\nVISION_BACKBONE_TO_TIMM_ID: Dict[str, List[str]] = {\n    \"clip-vit-l\": [\"vit_large_patch14_clip_224.openai\"],\n    \"clip-vit-l-336px\": [\"vit_large_patch14_clip_336.openai\"],\n\n    \"dinov2-vit-l\": [\"vit_large_patch14_reg4_dinov2.lvd142m\"],\n    \"in1k-vit-l\": [\"vit_large_patch16_224.augreg_in21k_ft_in1k\"],\n\n    \"siglip-vit-so400m\": [\"vit_so400m_patch14_siglip_224\"],\n    \"siglip-vit-so400m-384px\": [\"vit_so400m_patch14_siglip_384\"],\n\n    \"dinoclip-vit-l-336px\": [\"vit_large_patch14_reg4_dinov2.lvd142m\", \"vit_large_patch14_clip_336.openai\"],\n    \"dinosiglip-vit-so-224px\": [\"vit_large_patch14_reg4_dinov2.lvd142m\", \"vit_so400m_patch14_siglip_224\"],\n    \"dinosiglip-vit-so-384px\": [\"vit_large_patch14_reg4_dinov2.lvd142m\", \"vit_so400m_patch14_siglip_384\"],\n}\nTIMM_OVERRIDE_ACT_LAYER: Dict[str, List[Optional[str]]] = {\n    \"clip-vit-l\": [\"quick_gelu\"], \"clip-vit-l-336px\": [\"quick_gelu\"],\n    \"dinov2-vit-l\": [None], \"in1k-vit-l\": [None],\n    \"siglip-vit-so400m\": [None], \"siglip-vit-so400m-384px\": [None],\n    \"dinoclip-vit-l-336px\": [None, \"quick_gelu\"],\n    \"dinosiglip-vit-so-224px\": [None, None], \"dinosiglip-vit-so-384px\": [None, None]\n}\n\nLLM_BACKBONE_TO_HF_PATH = {\n    \"llama2-7b-pure\": \"meta-llama/Llama-2-7b-hf\", \"llama2-13b-pure\": \"meta-llama/Llama-2-13b-hf\",\n    \"llama2-7b-chat\": \"meta-llama/Llama-2-7b-chat-hf\", \"llama2-13b-chat\": \"meta-llama/Llama-2-13b-chat-hf\",\n\n    \"vicuna-v15-7b\": \"lmsys/vicuna-7b-v1.5\", \"vicuna-v15-13b\": \"lmsys/vicuna-13b-v1.5\",\n\n    \"mistral-v0.1-7b-pure\": \"mistralai/Mistral-7B-v0.1\",\n    \"mistral-v0.1-7b-instruct\": \"mistralai/Mistral-7B-Instruct-v0.1\",\n\n    \"phi-2-3b\": \"microsoft/phi-2\",\n}\nLLM_BACKBONE_TO_HF_METACLASS = {\n    \"llama2-7b-pure\": \"llama\", \"llama2-13b-pure\": \"llama\", \"llama2-7b-chat\": \"llama\", \"llama2-13b-chat\": \"llama\",\n    \"vicuna-v15-7b\": \"llama\", \"vicuna-v15-13b\": \"llama\",\n\n    \"mistral-v0.1-7b-pure\": \"mistral\", \"mistral-v0.1-7b-instruct\": \"mistral\",\n\n    \"phi-2-3b\": \"phi\",\n}\n\nVALID_VISION_BACKBONES = set(VISION_BACKBONE_TO_RESOLUTION.keys())\nVALID_LLM_BACKBONES = set(LLM_BACKBONE_TO_HF_PATH)\n# fmt: on\n\n\nclass PrismaticConfig(PretrainedConfig):\n    model_type: str = \"prismatic\"\n    is_composition: bool = False\n\n    def __init__(\n        self,\n        vision_backbone_id: str = \"siglip-vit-so400m\",\n        llm_backbone_id: str = \"vicuna-v15-7b\",\n        arch_specifier: str = \"no-align+gelu-mlp\",\n        use_fused_vision_backbone: Optional[bool] = None,\n        image_resize_strategy: str = \"letterbox\",\n        text_config: Optional[Dict[str, Any]] = 
None,\n        llm_max_length: int = 2048,\n        pad_token_id: int = 32000,\n        pad_to_multiple_of: int = 64,\n        output_projector_states: bool = False,\n        **kwargs: str,\n    ) -> None:\n        if vision_backbone_id not in VALID_VISION_BACKBONES:\n            raise ValueError(f\"Vision backbone `{vision_backbone_id}` not in {VALID_VISION_BACKBONES = }\")\n\n        if llm_backbone_id not in VALID_LLM_BACKBONES:\n            raise ValueError(f\"LLM backbone `{llm_backbone_id}` not in {VALID_LLM_BACKBONES = }\")\n\n        # Set Prismatic Configuration Fields\n        self.vision_backbone_id = vision_backbone_id\n        self.llm_backbone_id = llm_backbone_id\n        self.arch_specifier = arch_specifier\n        self.output_projector_states = output_projector_states\n\n        # [Contract] All vision backbone parameters are lists =>> supports fused backbones with different preprocessing\n        self.use_fused_vision_backbone = (\n            use_fused_vision_backbone\n            if use_fused_vision_backbone is not None\n            else any(self.vision_backbone_id.startswith(v) for v in [\"dinoclip\", \"dinosiglip\"])\n        )\n\n        self.timm_model_ids = VISION_BACKBONE_TO_TIMM_ID[self.vision_backbone_id]\n        self.timm_override_act_layers = TIMM_OVERRIDE_ACT_LAYER[self.vision_backbone_id]\n        self.image_sizes = VISION_BACKBONE_TO_RESOLUTION[self.vision_backbone_id]\n        self.image_resize_strategy = image_resize_strategy\n\n        self.hf_llm_id = LLM_BACKBONE_TO_HF_PATH[self.llm_backbone_id]\n        self.llm_max_length = llm_max_length\n        self.pad_token_id, self.pad_to_multiple_of = pad_token_id, pad_to_multiple_of\n\n        # [IMPORTANT] HF Utilities actually look for a `text_config` field... we need to use that specific naming!\n        self.text_config = (\n            CONFIG_MAPPING[LLM_BACKBONE_TO_HF_METACLASS[self.llm_backbone_id]](**text_config)\n            if text_config is not None\n            else CONFIG_MAPPING[LLM_BACKBONE_TO_HF_METACLASS[self.llm_backbone_id]]()\n        )\n\n        # Dispatch **kwargs to super() =>> note that `pad_token_id` collides, so we pass it in here as well...\n        super().__init__(pad_token_id=pad_token_id, **kwargs)\n\n\nclass OpenVLAConfig(PrismaticConfig):\n    model_type: str = \"openvla\"\n\n    def __init__(\n        self,\n        norm_stats: Optional[Dict[str, Dict[str, Dict[str, Dict[str, List[float]]]]]] = None,\n        n_action_bins: int = 256,\n        **kwargs: str,\n    ) -> None:\n        self.norm_stats, self.n_action_bins = norm_stats, n_action_bins\n\n        super().__init__(**kwargs)\n"
  },
  {
    "path": "siirl/models/embodied/openvla_oft/constants.py",
    "content": "\"\"\"\nImportant constants for VLA training and evaluation.\n\nAttempts to automatically identify the correct constants to set based on the Python command used to launch\ntraining or evaluation. If it is unclear, defaults to using the LIBERO simulation benchmark constants.\n\"\"\"\nimport sys\nfrom enum import Enum\n\n# Llama 2 token constants\nIGNORE_INDEX = -100\nACTION_TOKEN_BEGIN_IDX = 31743\nSTOP_INDEX = 2  # '</s>'\n\n\n# Defines supported normalization schemes for action and proprioceptive state.\nclass NormalizationType(str, Enum):\n    # fmt: off\n    NORMAL = \"normal\"               # Normalize to Mean = 0, Stdev = 1\n    BOUNDS = \"bounds\"               # Normalize to Interval = [-1, 1]\n    BOUNDS_Q99 = \"bounds_q99\"       # Normalize [quantile_01, ..., quantile_99] --> [-1, ..., 1]\n    # fmt: on\n\n\n# Define constants for each robot platform\nLIBERO_CONSTANTS = {\n    \"NUM_ACTIONS_CHUNK\": 8,\n    \"ACTION_DIM\": 7,\n    \"PROPRIO_DIM\": 8,\n    \"ACTION_PROPRIO_NORMALIZATION_TYPE\": NormalizationType.BOUNDS_Q99,\n}\n\nALOHA_CONSTANTS = {\n    \"NUM_ACTIONS_CHUNK\": 25,\n    \"ACTION_DIM\": 14,\n    \"PROPRIO_DIM\": 14,\n    \"ACTION_PROPRIO_NORMALIZATION_TYPE\": NormalizationType.BOUNDS,\n}\n\nBRIDGE_CONSTANTS = {\n    \"NUM_ACTIONS_CHUNK\": 5,\n    \"ACTION_DIM\": 7,\n    \"PROPRIO_DIM\": 7,\n    \"ACTION_PROPRIO_NORMALIZATION_TYPE\": NormalizationType.BOUNDS_Q99,\n}\n\n\n# Function to detect robot platform from command line arguments\ndef detect_robot_platform():\n    cmd_args = \" \".join(sys.argv).lower()\n\n    if \"libero\" in cmd_args:\n        return \"LIBERO\"\n    elif \"aloha\" in cmd_args:\n        return \"ALOHA\"\n    elif \"bridge\" in cmd_args:\n        return \"BRIDGE\"\n    else:\n        # Default to LIBERO if unclear\n        return \"LIBERO\"\n\n\n# Determine which robot platform to use\nROBOT_PLATFORM = detect_robot_platform()\n\n# Set the appropriate constants based on the detected platform\nif ROBOT_PLATFORM == \"LIBERO\":\n    constants = LIBERO_CONSTANTS\nelif ROBOT_PLATFORM == \"ALOHA\":\n    constants = ALOHA_CONSTANTS\nelif ROBOT_PLATFORM == \"BRIDGE\":\n    constants = BRIDGE_CONSTANTS\n\n# Assign constants to global variables\nNUM_ACTIONS_CHUNK = constants[\"NUM_ACTIONS_CHUNK\"]\nACTION_DIM = constants[\"ACTION_DIM\"]\nPROPRIO_DIM = constants[\"PROPRIO_DIM\"]\nACTION_PROPRIO_NORMALIZATION_TYPE = constants[\"ACTION_PROPRIO_NORMALIZATION_TYPE\"]\n\n# Print which robot platform constants are being used (for debugging)\nprint(f\"Using {ROBOT_PLATFORM} constants:\")\nprint(f\"  NUM_ACTIONS_CHUNK = {NUM_ACTIONS_CHUNK}\")\nprint(f\"  ACTION_DIM = {ACTION_DIM}\")\nprint(f\"  PROPRIO_DIM = {PROPRIO_DIM}\")\nprint(f\"  ACTION_PROPRIO_NORMALIZATION_TYPE = {ACTION_PROPRIO_NORMALIZATION_TYPE}\")\nprint(\"If needed, manually set the correct constants in `prismatic/vla/constants.py`!\")\n"
  },
  {
    "path": "siirl/models/embodied/openvla_oft/modeling_prismatic.py",
    "content": "\"\"\"\nmodeling_prismatic.py\n\nCore HuggingFace-style PrismaticPreTrainedModel and PrismaticForConditionalGeneration class definitions.\nInherits from the default `transformers.PretrainedModel`. Meant to be standalone and self-contained,\nbut exactly replicate the logic in `prismatic.models.vlms.prismatic.py`.\n\"\"\"\n\nimport logging\nfrom dataclasses import dataclass\nfrom functools import partial\nfrom typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport timm\nimport tokenizers\nimport torch\nimport torch.nn as nn\nimport transformers\nfrom timm.models.vision_transformer import LayerScale\nfrom transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel\nfrom transformers.modeling_outputs import ModelOutput\n\nfrom .train_utils import (\n    get_current_action_mask,\n    get_next_actions_mask,\n)\nfrom .constants import (\n    ACTION_DIM,\n    ACTION_PROPRIO_NORMALIZATION_TYPE,\n    ACTION_TOKEN_BEGIN_IDX,\n    IGNORE_INDEX,\n    NUM_ACTIONS_CHUNK,\n    STOP_INDEX,\n    NormalizationType,\n)\n\nfrom .configuration_prismatic import OpenVLAConfig, PrismaticConfig\n\n# Set up logger\nlogger = logging.getLogger(__name__)\n\n\n# === Utility Functions for Monkey-Patching ===\ndef unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:\n    def wrapper(*args: Any, **kwargs: Any) -> Any:\n        result = fn(*args, **kwargs)\n        return result[0] if isinstance(result, tuple) else result\n\n    return wrapper\n\n\n# HF Transformers overwrites parameters with names containing `gamma`; we're going to patch VisionBackbone.LayerScale.\n#   =>> TIMM :: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L109\n#   =>> Transformers :: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3960\ndef _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor:\n    return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor\n\n\ndef ls_apply_patch(ls_module: LayerScale):\n    ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone())\n    ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale)\n    del ls_module.gamma\n\n\n# === Prismatic Vision Backbone (nn.Module) Definitions (w/ Fused Backbone Support) ===\nclass PrismaticVisionBackbone(nn.Module):\n    \"\"\"\n    Vision backbone for Prismatic models that handles image feature extraction.\n\n    Supports both single backbone (e.g., SigLIP) and fused backbone (e.g., SigLIP + DINOv2) configurations.\n    For fused backbones, features from both models are concatenated along the feature dimension.\n    \"\"\"\n\n    def __init__(\n        self,\n        use_fused_vision_backbone: bool,\n        image_sizes: List[int],\n        timm_model_ids: List[str],\n        timm_override_act_layers: List[Optional[str]],\n    ) -> None:\n        \"\"\"\n        Initialize the vision backbone.\n\n        Args:\n            use_fused_vision_backbone: Whether to use two backbones and fuse their features\n            image_sizes: List of image sizes for each backbone\n            timm_model_ids: List of TIMM model IDs to use for each backbone\n            timm_override_act_layers: List of activation layer overrides for each backbone\n        \"\"\"\n        super().__init__()\n        self.use_fused_vision_backbone = use_fused_vision_backbone\n        self.num_images_in_input = 1  # Default value, can be overridden later\n\n        # Validate number of (fused) 
vision backbones\n        if len(timm_model_ids) > 2:\n            raise ValueError(\"Prismatic models only support up to 2 (fused) vision backbones!\")\n\n        # Create primary featurizer\n        self.featurizer = self._create_featurizer(\n            model_id=timm_model_ids[0], img_size=image_sizes[0], act_layer=timm_override_act_layers[0]\n        )\n        self.embed_dim = self.featurizer.embed_dim\n\n        # Create secondary featurizer if using fused backbone\n        if self.use_fused_vision_backbone:\n            self.fused_featurizer = self._create_featurizer(\n                model_id=timm_model_ids[1], img_size=image_sizes[1], act_layer=timm_override_act_layers[1]\n            )\n            self.embed_dim += self.fused_featurizer.embed_dim\n\n        # Patch LayerScale modules for HF compatibility\n        self._patch_layer_scales()\n\n    def _create_featurizer(self, model_id: str, img_size: int, act_layer: Optional[str]) -> nn.Module:\n        \"\"\"\n        Create a TIMM-based featurizer model with appropriate configurations.\n\n        Args:\n            model_id: The TIMM model ID to load\n            img_size: Input image size for the model\n            act_layer: Override for the activation layer type\n\n        Returns:\n            A configured featurizer model\n        \"\"\"\n        featurizer = timm.create_model(\n            model_id,\n            pretrained=False,\n            num_classes=0,\n            img_size=img_size,\n            act_layer=act_layer,\n        )\n\n        # Monkey-patch the forward function to extract the second-to-last layer features\n        num_blocks = len(featurizer.blocks)\n        featurizer.forward = unpack_tuple(partial(featurizer.get_intermediate_layers, n={num_blocks - 2}))\n\n        return featurizer\n\n    def _patch_layer_scales(self) -> None:\n        \"\"\"\n        Patch all LayerScale modules to be compatible with HF's parameter naming.\n\n        HF Transformers overwrites parameters with names containing 'gamma',\n        so we need to rename and modify the forward method.\n        \"\"\"\n        # Patch primary featurizer\n        for module in self.featurizer.modules():\n            if isinstance(module, LayerScale):\n                ls_apply_patch(module)\n\n        # Patch secondary featurizer if it exists\n        if self.use_fused_vision_backbone:\n            for module in self.fused_featurizer.modules():\n                if isinstance(module, LayerScale):\n                    ls_apply_patch(module)\n\n    def get_num_patches(self) -> int:\n        \"\"\"\n        Returns the number of vision patches output by the vision backbone.\n\n        Returns:\n            Number of patches per image\n        \"\"\"\n        return self.featurizer.patch_embed.num_patches\n\n    def get_num_images_in_input(self) -> int:\n        \"\"\"\n        Returns the number of input images for the vision backbone.\n\n        Returns:\n            Number of images expected in the input\n        \"\"\"\n        return self.num_images_in_input\n\n    def set_num_images_in_input(self, num_images_in_input: int) -> None:\n        \"\"\"\n        Sets the number of input images for the vision backbone.\n\n        Args:\n            num_images_in_input: Number of images to expect in the input\n        \"\"\"\n        self.num_images_in_input = num_images_in_input\n\n    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Implements the forward pass for the vision backbone.\n\n        If 
`self.use_fused_vision_backbone == True`, uses both SigLIP and DINOv2 transformers to extract visual features\n        (otherwise uses SigLIP only). Allows multi-image inputs (but only for fused vision backbone).\n\n        Args:\n            pixel_values (torch.Tensor): Pixels for input image(s), (B, C, H, W).\n        \"\"\"\n        if self.num_images_in_input == 1:\n            if not self.use_fused_vision_backbone:\n                return self.featurizer(pixel_values)\n\n            # Split `pixel_values :: [bsz, 2 * 3, resolution, resolution]` =>> featurize =>> channel stack\n            img, img_fused = torch.split(pixel_values, [3, 3], dim=1)\n            patches, patches_fused = self.featurizer(img), self.fused_featurizer(img_fused)\n\n            return torch.cat([patches, patches_fused], dim=2)\n\n        else:\n            assert self.use_fused_vision_backbone, \"Multi-image inputs require using fused backbone!\"\n\n            # Split `pixel_values` into individual images (each with 6 channels: 3 for SigLIP + 3 for DINOv2)\n            images = torch.split(pixel_values, [6] * self.num_images_in_input, dim=1)\n\n            # Process each image and collect patches\n            all_patches = []\n            for img in images:\n                # Split each image further into two stacks of channels (each with 3 channels)\n                img_regular, img_fused = torch.split(img, [3, 3], dim=1)\n\n                # Get patches from both SigLIP and DINOv2 vision transformers\n                patches = self.featurizer(img_regular)\n                patches_fused = self.fused_featurizer(img_fused)\n\n                # Concatenate SigLIP and DINOv2 patches along the hidden dimension\n                combined_patches = torch.cat([patches, patches_fused], dim=2)\n                all_patches.append(combined_patches)\n\n            # Concatenate all patches along the patch dimension\n            return torch.cat(all_patches, dim=1)\n\n\n# === Prismatic Projector (nn.Module) Definitions ===\nclass PrismaticProjector(nn.Module):\n    def __init__(self, use_fused_vision_backbone: bool, vision_dim: int, llm_dim: int) -> None:\n        super().__init__()\n        self.use_fused_vision_backbone = use_fused_vision_backbone\n        self.vision_dim, self.llm_dim = vision_dim, llm_dim\n\n        # Switch on `use_fused_vision_backbone` =>> use slightly different MLPs and projection factors!\n        if not self.use_fused_vision_backbone:\n            self.fc1 = nn.Linear(self.vision_dim, self.llm_dim, bias=True)\n            self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)\n            self.act_fn1 = nn.GELU()\n        else:\n            initial_projection_dim = 4 * vision_dim\n            self.fc1 = nn.Linear(self.vision_dim, initial_projection_dim, bias=True)\n            self.fc2 = nn.Linear(initial_projection_dim, self.llm_dim, bias=True)\n            self.fc3 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)\n            self.act_fn1 = nn.GELU()\n            self.act_fn2 = nn.GELU()\n\n    def forward(self, img_patches: torch.Tensor) -> torch.Tensor:\n        if not self.use_fused_vision_backbone:\n            projected_features = self.fc1(img_patches)\n            projected_features = self.act_fn1(projected_features)\n            projected_features = self.fc2(projected_features)\n        else:\n            projected_features = self.fc1(img_patches)\n            projected_features = self.act_fn1(projected_features)\n            projected_features = self.fc2(projected_features)\n      
      projected_features = self.act_fn2(projected_features)\n            projected_features = self.fc3(projected_features)\n\n        return projected_features\n\n\n# === Main HF Class Definitions ===\n@dataclass\nclass PrismaticCausalLMOutputWithPast(ModelOutput):\n    \"\"\"Base class for Prismatic causal (visually-conditioned) language model outputs; also exposes visual features.\"\"\"\n\n    loss: Optional[torch.FloatTensor] = None\n    logits: torch.FloatTensor = None\n    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None\n    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None\n    attentions: Optional[Tuple[torch.FloatTensor]] = None\n\n    # Additions for VLMs\n    projector_features: Optional[torch.FloatTensor] = None\n\n\nclass PrismaticPreTrainedModel(PreTrainedModel):\n    config_class: PretrainedConfig = PrismaticConfig\n    base_model_prefix: str = \"model\"\n    supports_gradient_checkpointing: bool = True\n\n    _no_split_modules: ClassVar[List[str]] = [\"PrismaticProjector\"]\n    _skip_keys_device_placement: str = \"past_key_values\"\n    _supports_flash_attn_2: bool = True\n\n    def _init_weights(self, module: nn.Module) -> None:\n        # Important :: this HF ported version is *not* meant for training from scratch; only inference and fine-tuning!\n        #   => As such, this init_weights code is not correct; if training VLMs from scratch, use the main codebase at\n        #      https://github.com/TRI-ML/prismatic-vlms\n        std = (\n            self.config.initializer_range\n            if hasattr(self.config, \"initializer_range\")\n            else self.config.text_config.initializer_range\n        )\n\n        if hasattr(module, \"class_embedding\"):\n            module.class_embedding.data.normal_(mean=0.0, std=std)\n\n        if isinstance(module, (nn.Linear, nn.Conv2d)):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n\n    @property\n    def _supports_sdpa(self) -> bool:\n        \"\"\"Check LLM supports SDPA Attention\"\"\"\n        return self.language_model._supports_sdpa\n\n\nclass PrismaticForConditionalGeneration(PrismaticPreTrainedModel):\n    def __init__(self, config: PrismaticConfig) -> None:\n        super().__init__(config)\n\n        # [Validation] Lightweight Validate on `config` Fields + Dependency Versions\n        if config.use_fused_vision_backbone is None:\n            raise ValueError(\"Missing config field `use_fused_vision_backbone`\")\n\n        if timm.__version__ not in {\"0.9.10\", \"0.9.11\", \"0.9.12\", \"0.9.16\"}:\n            raise NotImplementedError(\n                \"TIMM Version must be >= 0.9.10 and < 1.0.0 (breaking); please raise a GitHub Issue \"\n                \"if you urgently need support for latest TIMM versions.\"\n            )\n\n        if (transformers.__version__ != \"4.40.1\") or (tokenizers.__version__ != \"0.19.1\"):\n            logger.warning(\n                f\"Expected `transformers==4.40.1` and `tokenizers==0.19.1` but got \"\n                f\"`transformers=={transformers.__version__}` and `tokenizers=={tokenizers.__version__}`; \"\n                f\"there might be inference-time regressions due to dependency changes.
 If in doubt, please \"\n                f\"use the above versions.\"\n            )\n\n        # Instantiate PrismaticVisionBackbone (w/ Potential Fused Backbone)\n        self.vision_backbone = PrismaticVisionBackbone(\n            config.use_fused_vision_backbone, config.image_sizes, config.timm_model_ids, config.timm_override_act_layers\n        )\n\n        # Create Multimodal Projector\n        self.projector = PrismaticProjector(\n            config.use_fused_vision_backbone,\n            vision_dim=self.vision_backbone.embed_dim,\n            llm_dim=config.text_config.hidden_size,\n        )\n\n        # Instantiate LLM Backbone\n        self.language_model = AutoModelForCausalLM.from_config(\n            config.text_config, attn_implementation=config._attn_implementation\n        )\n        self.vocab_size = config.text_config.vocab_size\n        self.pad_token_id = config.pad_token_id\n        self.llm_dim = config.text_config.hidden_size\n\n        # HF Boilerplate =>> initializes weights via `_init_weights()` and sets gradient checkpointing\n        self.post_init()\n\n    # === `PreTrainedModel` Boilerplate ===\n    def get_input_embeddings(self) -> nn.Module:\n        return self.language_model.get_input_embeddings()\n\n    def set_input_embeddings(self, value: nn.Module) -> None:\n        self.language_model.set_input_embeddings(value)\n\n    def get_output_embeddings(self) -> nn.Module:\n        return self.language_model.get_output_embeddings()\n\n    def set_output_embeddings(self, new_embeddings: nn.Module) -> None:\n        self.language_model.set_output_embeddings(new_embeddings)\n\n    def get_decoder(self) -> nn.Module:\n        return self.language_model.get_decoder()\n\n    def set_decoder(self, decoder: nn.Module) -> None:\n        self.language_model.set_decoder(decoder)\n\n    def tie_weights(self) -> None:\n        self.language_model.tie_weights()  # Note: `Llama-2` and `Mistral` don't tie weights (no-op)\n\n    def resize_token_embeddings(\n        self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None\n    ) -> nn.Embedding:\n        updated_embeddings = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)\n\n        # Update config/instance variables\n        self.config.text_config.vocab_size = updated_embeddings.num_embeddings\n        self.vocab_size = updated_embeddings.num_embeddings\n\n        return updated_embeddings\n\n    def _replace_input_embeddings(self, input_embeddings, all_actions_mask, noisy_action_features):\n        \"\"\"\n        Replace embeddings in input_embeddings at positions where all_actions_mask is True\n        with embeddings from noisy_action_features, using vectorized operations.\n\n        Args:\n            input_embeddings: Tensor of shape (B, S, D)\n            all_actions_mask: Boolean tensor of shape (B, S)\n            noisy_action_features: Tensor of shape (B, K, D) where K is the number of True values in mask per sample\n\n        Returns:\n            Modified input_embeddings tensor\n        \"\"\"\n        # Clone input to avoid modifying the original tensor\n        new_input_embeddings = input_embeddings.clone()\n\n        # Create a tensor with the same shape as input_embeddings to hold the noisy action features\n        repositioned_noisy_action_features = torch.zeros_like(input_embeddings)\n\n        # Create batch indices for splicing\n        batch_indices = torch.arange(input_embeddings.shape[0], device=input_embeddings.device)\n        batch_indices 
= batch_indices.unsqueeze(1).expand(-1, noisy_action_features.shape[1])\n\n        # Get indices where mask is True for each sample\n        masked_indices = torch.stack([torch.where(mask)[0] for mask in all_actions_mask])\n\n        # Move the noisy action features into their correct positions\n        repositioned_noisy_action_features[batch_indices, masked_indices] = noisy_action_features\n\n        # Combine original input embeddings and noisy action embeddings using the mask\n        new_input_embeddings = torch.where(\n            all_actions_mask.unsqueeze(-1), repositioned_noisy_action_features, new_input_embeddings\n        )\n\n        return new_input_embeddings\n\n    def _process_action_masks(self, labels):\n        \"\"\"Helper to get action masks from labels\"\"\"\n        current_action_mask = get_current_action_mask(labels)\n        next_actions_mask = get_next_actions_mask(labels)\n        all_actions_mask = current_action_mask | next_actions_mask  # (B, seq_len)\n        return all_actions_mask\n\n    def _process_vision_features(self, pixel_values, language_embeddings=None, use_film=False):\n        \"\"\"Process vision features with optional FiLM conditioning\"\"\"\n        if use_film:\n            # FiLM: Infuse language inputs into visual features\n            patch_features = self.vision_backbone(pixel_values, language_embeddings)  # (bsz, 256 * num_images, D)\n        else:\n            patch_features = self.vision_backbone(pixel_values)  # (bsz, 256 * num_images, D)\n\n        # Project patch embeddings into language embedding space\n        return self.projector(patch_features)\n\n    def _process_proprio_features(self, projected_patch_embeddings, proprio, proprio_projector):\n        \"\"\"Process proprioceptive features and append to vision features\"\"\"\n        if proprio_projector is not None and proprio is not None:\n            # projected_patch_embeddings: (bsz, num_patches * num_images, llm_dim)\n            # proprio: (bsz, proprio_dim) or (propro_dim,)\n            proprio = proprio.reshape(projected_patch_embeddings.shape[0], -1)  # (bsz, proprio_dim)\n            proprio_features = proprio_projector(proprio)  # (bsz, llm_dim)\n            proprio_features = proprio_features.unsqueeze(dim=1)  # (bsz, 1, llm_dim)\n            # For simplicity, just append proprio token to the end of projected vision patch tokens\n            return torch.cat((projected_patch_embeddings, proprio_features), dim=1)\n        return projected_patch_embeddings\n\n    def _build_multimodal_attention(self, input_embeddings, projected_patch_embeddings, attention_mask):\n        \"\"\"Build multimodal embeddings and attention mask\"\"\"\n        # Update attention mask\n        projected_patch_attention_mask = None\n        if attention_mask is not None:\n            projected_patch_attention_mask = torch.full(\n                (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),\n                fill_value=True,\n                dtype=attention_mask.dtype,\n                device=attention_mask.device,\n            )\n\n        # Build multimodal embeddings & attention mask; insert embeddings after <BOS> token (1:)\n        multimodal_embeddings = torch.cat(\n            [input_embeddings[:, :1, :], projected_patch_embeddings, input_embeddings[:, 1:, :]], dim=1\n        )\n\n        multimodal_attention_mask = None\n        if attention_mask is not None:\n            multimodal_attention_mask = torch.cat(\n                [attention_mask[:, :1], 
projected_patch_attention_mask, attention_mask[:, 1:]], dim=1\n            )\n\n        return multimodal_embeddings, multimodal_attention_mask\n\n    def _build_multimodal_labels(self, labels, projected_patch_embeddings):\n        \"\"\"Build multimodal labels with IGNORE_INDEX for patch embeddings\"\"\"\n        if labels is not None:\n            projected_patch_labels = torch.full(\n                (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),\n                fill_value=IGNORE_INDEX,\n                dtype=labels.dtype,\n                device=labels.device,\n            )\n            return torch.cat([labels[:, :1], projected_patch_labels, labels[:, 1:]], dim=1)\n        return None\n\n    # === Core Prismatic VLM `forward()` Logic ===\n    # def forward(\n    #     self,\n    #     input_ids: Optional[torch.LongTensor] = None,\n    #     attention_mask: Optional[torch.Tensor] = None,\n    #     pixel_values: Optional[torch.FloatTensor] = None,\n    #     labels: Optional[torch.LongTensor] = None,\n    #     inputs_embeds: Optional[torch.FloatTensor] = None,\n    #     past_key_values: Optional[List[torch.FloatTensor]] = None,\n    #     use_cache: Optional[bool] = None,\n    #     output_attentions: Optional[bool] = None,\n    #     output_hidden_states: Optional[bool] = None,\n    #     output_projector_features: Optional[bool] = None,\n    #     return_dict: Optional[bool] = None,\n    #     proprio=None,\n    #     proprio_projector=None,\n    #     noisy_actions=None,\n    #     noisy_action_projector=None,\n    #     diffusion_timestep_embeddings=None,\n    #     use_film: bool = False,\n    # ) -> Union[Tuple, PrismaticCausalLMOutputWithPast]:\n    #     \"\"\"Run a forward pass through the VLM, returning a PrismaticCausalLMOutputWithPast instance.\"\"\"\n    #     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n    #     output_hidden_states = (\n    #         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n    #     )\n    #     output_projector_features = output_projector_features if output_projector_features is not None else False\n    #     return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n    #     # Respect `use_cache` only if not training (even if `gradient_checkpointing` is off)\n    #     use_cache = use_cache and not self.training\n\n    #     # Instantiate Placeholder for Projector Features\n    #     projected_patch_embeddings = None\n\n    #     # === Handle Generation with Cache (`input_ids.shape[1] == 1`) =>> requires `past_keys_values` ===\n    #     if input_ids.shape[1] == 1:\n    #         assert input_ids.shape[0] == 1, \"Generation is only currently supported for batch size of 1!\"\n    #         assert past_key_values is not None, \"You must provide `past_key_values` during cached generation!\"\n    #         assert labels is None, \"Unexpected key `labels` provided during cached generation!\"\n\n    #         language_model_output = self.language_model(\n    #             input_ids=input_ids,\n    #             attention_mask=None,\n    #             position_ids=None,\n    #             past_key_values=past_key_values,\n    #             inputs_embeds=None,\n    #             labels=None,\n    #             use_cache=use_cache,\n    #             output_attentions=output_attentions,\n    #             output_hidden_states=output_hidden_states,\n    #             
return_dict=return_dict,\n    #         )\n\n    #     # === Handle Unimodal Forward ===\n    #     elif pixel_values is None:\n    #         assert (input_ids is not None) and (inputs_embeds is None), \"Missing `input_ids` in language-only forward!\"\n    #         assert past_key_values is None, \"Unexpected key `past_key_values` provided during language-only forward!\"\n\n    #         language_model_output = self.language_model(\n    #             input_ids=input_ids,\n    #             attention_mask=attention_mask,\n    #             position_ids=None,\n    #             past_key_values=None,\n    #             inputs_embeds=None,\n    #             labels=labels,\n    #             use_cache=use_cache,\n    #             output_attentions=output_attentions,\n    #             output_hidden_states=output_hidden_states,\n    #             return_dict=return_dict,\n    #         )\n\n    #     # === Handle Multimodal Forward ===\n    #     elif (input_ids.shape[0] == pixel_values.shape[0]) or (inputs_embeds.shape[0] == pixel_values.shape[0]):\n    #         assert past_key_values is None, \"Unexpected key `past_key_values` provided during multimodal forward!\"\n            \n    #         #test\n    #         \n    #         #test end\n                \n    #         # Get input embeddings (from language model embeddings)\n    #         input_embeddings = self.get_input_embeddings()(input_ids)  # (B, seq_len, D)\n\n    #         # Extract action masks\n    #         all_actions_mask = self._process_action_masks(labels)\n\n    #         # Extract the language portion of the input embeddings (i.e. remove the action tokens portion)\n    #         language_embeddings = input_embeddings[~all_actions_mask].reshape(\n    #             input_embeddings.shape[0], -1, input_embeddings.shape[2]\n    #         )  # (B, lang_seq_len, llm_dim)\n\n    #         # Get visual features\n    #         projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film)\n\n    #         # Add proprioceptive state if provided\n    #         projected_patch_embeddings = self._process_proprio_features(\n    #             projected_patch_embeddings, proprio, proprio_projector\n    #         )\n\n    #         # [Diffusion] Add diffusion timestep embedding if provided\n    #         if diffusion_timestep_embeddings is not None:\n    #             # For simplicity, just append diffusion timestep embedding to the end of projected vision patch tokens\n    #             projected_patch_embeddings = torch.cat(\n    #                 (projected_patch_embeddings, diffusion_timestep_embeddings), dim=1\n    #             )\n\n    #         # Process action embeddings\n    #         if noisy_actions is not None:\n    #             # Get mask corresponding to all action tokens\n    #             all_actions_mask = self._process_action_masks(labels)\n\n    #             # Reshape noisy actions into individual action tokens\n    #             # noisy_actions: (B, chunk_len, action_dim) -> (B, chunk_len * action_dim, 1)\n    #             B = noisy_actions.shape[0]\n    #             noisy_actions = noisy_actions.reshape(B, -1).unsqueeze(-1)\n\n    #             # Project noisy action tokens into language model embedding space\n    #             noisy_action_features = noisy_action_projector(noisy_actions)  # (B, chunk_len * action_dim, llm_dim)\n\n    #             # Replace embeddings of the action tokens with noisy action embeddings\n    #             input_embeddings = 
self._replace_input_embeddings(\n    #                 input_embeddings, all_actions_mask, noisy_action_features\n    #             )\n    #         else:\n    #             # Replace the embeddings of the action tokens with zeros\n    #             # (Later on, the positional embeddings will be added to them)\n    #             all_actions_mask = all_actions_mask.unsqueeze(-1)  # (B, seq_len, 1)\n    #             input_embeddings = input_embeddings * ~all_actions_mask\n\n    #         # Build multimodal embeddings & attention mask\n    #         multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(\n    #             input_embeddings, projected_patch_embeddings, attention_mask\n    #         )\n\n    #         # Build labels for multimodal sequence if needed\n    #         multimodal_labels = self._build_multimodal_labels(labels, projected_patch_embeddings)\n\n    #         # Dispatch to language model\n    #         language_model_output = self.language_model(\n    #             input_ids=None,\n    #             attention_mask=multimodal_attention_mask,\n    #             position_ids=None,\n    #             past_key_values=None,\n    #             inputs_embeds=multimodal_embeddings,\n    #             labels=multimodal_labels,\n    #             use_cache=use_cache,\n    #             output_attentions=output_attentions,\n    #             output_hidden_states=output_hidden_states,\n    #             return_dict=return_dict,\n    #         )\n\n    #     # === Otherwise =>> Assume Invalid! ===\n    #     elif (input_ids.shape[0] != pixel_values.shape[0]) or (inputs_embeds.shape[0] != pixel_values.shape[0]):\n    #         raise ValueError(\"Non-homogenous batch of (text, image) input -- forward() does not support mixed batches!\")\n\n    #     else:\n    #         raise ValueError(\n    #             \"Invalid PrismaticForConditionalGeneration `forward()` call with provided arguments:\\n\"\n    #             f\"=> `input_ids` = {input_ids is not None}\\n\"\n    #             f\"=> `attention_mask` = {attention_mask is not None}\\n\"\n    #             f\"=> `pixel_values` = {pixel_values is not None}\\n\"\n    #             f\"=> `labels` = {labels is not None}\\n\"\n    #             f\"=> `input_embeds` = {inputs_embeds is not None}\\n\"\n    #             f\"=> `past_key_values` = {past_key_values is not None}\\n\"\n    #             f\"=> `use_cache` = {use_cache}\"\n    #         )\n\n    #     # Unpack `language_model_output` and return PrismaticCausalLMOutputWithPast (or tuple if not `return_dict`)\n    #     if not return_dict:\n    #         if output_projector_features and (projected_patch_embeddings is not None):\n    #             return *language_model_output, projected_patch_embeddings\n\n    #         return language_model_output\n\n    #     return PrismaticCausalLMOutputWithPast(\n    #         loss=language_model_output.loss,\n    #         logits=language_model_output.logits,\n    #         past_key_values=language_model_output.past_key_values,\n    #         hidden_states=language_model_output.hidden_states,\n    #         attentions=language_model_output.attentions,\n    #         projector_features=projected_patch_embeddings,\n    #     )\n\n    # === GenerationMixin Methods ===\n    def prepare_inputs_for_generation(\n        self,\n        input_ids: Optional[torch.Tensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        pixel_values: 
Optional[torch.FloatTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        **kwargs: str,\n    ) -> Dict[str, torch.Tensor]:\n        \"\"\"Borrowed from `LlamaForCausalLM` and simplified for batch size = 1; mirrors original PrismaticVLM logic.\"\"\"\n        if ((input_ids is not None) and (input_ids.shape[0] > 1)) or (\n            (inputs_embeds is not None) and (inputs_embeds.shape[0] > 1)\n        ):\n            raise ValueError(\"Generation with batch size > 1 is not currently supported!\")\n\n        # Handle `past_key_values` (cache) =>> assume `input_ids` just has unprocessed tokens\n        if past_key_values is not None:\n            input_ids = input_ids[:, -1:]\n\n        # If `input_embeds` are passed, we only want to use them in the 1st generation step\n        if inputs_embeds is not None and past_key_values is None:\n            model_inputs = {\"input_embeds\": inputs_embeds}\n        else:\n            model_inputs = {\"input_ids\": input_ids}\n\n        # Make sure `pixel_values` are preserved in `model_inputs`\n        model_inputs.update(\n            {\n                \"attention_mask\": attention_mask,\n                \"pixel_values\": pixel_values,\n                \"past_key_values\": past_key_values,\n                \"use_cache\": kwargs.get(\"use_cache\"),\n            }\n        )\n\n        return model_inputs\n\n    # Defer to Language Model (all handle this differently, with different return types)\n    def _reorder_cache(self, *args, **kwargs) -> Any:\n        return self.language_model._reorder_cache(*args, **kwargs)\n    \n    def _prepare_input_for_action_prediction_verl(self, input_ids, attention_mask):\n        \"\"\"Prepares input for action prediction by adding necessary tokens\"\"\"\n        # Add (ACTION_DIM * NUM_ACTIONS_CHUNK) placeholder tokens to input_ids to simulate action tokens\n        placeholder_action_token_ids = (\n            torch.ones((input_ids.shape[0], ACTION_DIM * NUM_ACTIONS_CHUNK)).to(input_ids.device).to(input_ids.dtype)\n        )\n        input_ids = torch.cat([input_ids, placeholder_action_token_ids], dim=-1)\n\n        # Add stop token to sequence (needed in non-causal bi-directional self-attention, as it appears at train time)\n        stop_token_id = torch.ones((input_ids.shape[0], 1)).to(input_ids.device).to(input_ids.dtype) * STOP_INDEX\n        input_ids = torch.cat([input_ids, stop_token_id], dim=-1)\n\n        # Extend the attention mask to fit the new shape of input\n        # Note: Only batch size == 1 supported right now\n        mask_extension = (\n            torch.ones((attention_mask.shape[0], input_ids.shape[-1] - attention_mask.shape[-1]))\n            .to(attention_mask.device)\n            .to(attention_mask.dtype)\n        )\n        attention_mask = torch.cat([attention_mask, mask_extension], dim=-1)\n\n        return input_ids, attention_mask\n\n    def _prepare_labels_for_action_prediction_verl(self, labels, input_ids):\n        \"\"\"Creates labels tensor for action prediction if not provided\"\"\"\n        # Extend labels tensor with fake action labels\n        ARBITRARY_ACTION_TOKEN_IDX = ACTION_TOKEN_BEGIN_IDX + 1\n        labels_extension = (\n            torch.ones((labels.shape[0], input_ids.shape[-1] - labels.shape[-1])).to(labels.device).to(labels.dtype)\n            * ARBITRARY_ACTION_TOKEN_IDX\n        )\n        labels = torch.cat([labels, labels_extension], dim=-1)\n\n        # Replace last label token with stop token\n        labels[:, -1] = STOP_INDEX\n\n  
      return labels\n    \n    def _verl_discrete_compute_logits(\n        self,\n        input_embeddings,\n        all_actions_mask,\n        projected_patch_embeddings,\n        attention_mask,\n        labels,\n        NUM_PATCHES,\n        NUM_PROMPT_TOKENS,\n        action_head=None,\n    ):\n        \"\"\"Compute the language-model logits at the discrete action-token positions (action_head is accepted for API compatibility but unused).\"\"\"\n        # Zero out action token embeddings\n        all_actions_mask = all_actions_mask.unsqueeze(-1)  # (B, seq_len, 1)\n        input_embeddings = input_embeddings * ~all_actions_mask\n\n        # Build multimodal embeddings and attention mask\n        multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(\n            input_embeddings, projected_patch_embeddings, attention_mask\n        )\n\n        # Forward pass through language model\n        language_model_output = self.language_model(\n            input_ids=None,\n            attention_mask=multimodal_attention_mask,\n            position_ids=None,\n            past_key_values=None,\n            inputs_embeds=multimodal_embeddings,\n            labels=None,\n            use_cache=None,\n            output_attentions=False,\n            output_hidden_states=False,\n            return_dict=True,\n        )\n\n        # Extract hidden states for action tokens\n        #last_hidden_states = language_model_output.hidden_states[-1]  # (B, seq_len, D)\n        # actions_hidden_states = last_hidden_states[\n        #     :,\n        #     NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,\n        #     :,\n        # ]  # (B, act_chunk_len, D)\n\n        # Handle different prediction methods\n        # if action_head is not None:\n        #     # L1 regression prediction\n        #     normalized_actions = action_head.predict_action(actions_hidden_states)\n        #     normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)\n        #     normalized_actions = normalized_actions.float().cpu().detach().numpy()\n        # else:\n        # Discrete token-based prediction: slice out the logits at the action-token positions\n\n        compute_logits = language_model_output.logits[\n            :,\n            NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,\n        ]\n\n        return compute_logits\n    \n    # def forward(\n    #     self,\n    #     input_ids: Optional[torch.LongTensor] = None,\n    #     unnorm_key: Optional[str] = None,\n    #     proprio=None,\n    #     proprio_projector=None,\n    #     action_head=None,\n    #     noisy_action_projector=None,\n    #     use_film: bool = False,\n    #     **kwargs: str,\n    # ) :\n    #     \"\"\"Predict actions from input sequence, with options for different prediction methods.\n\n    #     Args:\n    #         input_ids: Input token ids\n    #         unnorm_key: Key for unnormalization statistics\n    #         proprio: Proprioceptive features\n    #         proprio_projector: Projector for proprioceptive features\n    #         action_head: Optional head for L1 regression or diffusion-based prediction\n    #         noisy_action_projector: Projector for noisy actions in diffusion-based prediction\n    #         use_film: Whether to use FiLM conditioning\n    #         **kwargs: Additional arguments including pixel_values and attention_mask\n\n    #     Returns:\n    #         Tuple of (unnormalized_actions, 
action_hidden_states)\n    #     \"\"\"\n    #     # If the special empty token ('') does not already appear after the colon (':') token in the prompt\n    #     # (after \"OUT:\" or \"ASSISTANT:\"), insert it to match the inputs seen at training time\n    #     # if not torch.all(input_ids[:, -1] == 29871):\n    #     #     input_ids = torch.cat(\n    #     #         (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1\n    #     #     )\n    #     #print(\"!!!!!!!!!!!!!!Entering forward!!!!!!!!!!\")\n    #     pixel_values = kwargs[\"pixel_values\"]\n    #     attention_mask = kwargs[\"attention_mask\"]\n        \n    #     # Create fake labels tensor (needed for action mask)\n    #     labels = input_ids.clone()\n    #     labels[:] = IGNORE_INDEX\n\n    #     # Get number of tokens in prompt (excluding the start token)\n    #     NUM_PROMPT_TOKENS = input_ids.shape[-1] - 1  # Subtract action tokens and stop token\n\n    #     # Prepare inputs by adding necessary tokens\n    #     #input_ids, attention_mask = self._prepare_input_for_action_prediction_verl(input_ids, attention_mask)\n        \n    #     #test\n    #     placeholder_action_token_ids = (\n    #         torch.ones((input_ids.shape[0], ACTION_DIM * NUM_ACTIONS_CHUNK)).to(input_ids.device).to(input_ids.dtype)\n    #     )\n    #     input_ids = torch.cat([input_ids, placeholder_action_token_ids], dim=-1)\n\n    #     # Add stop token to sequence (needed in non-causal bi-directional self-attention, as it appears at train time)\n    #     stop_token_id = torch.ones((input_ids.shape[0], 1)).to(input_ids.device).to(input_ids.dtype) * STOP_INDEX\n    #     input_ids = torch.cat([input_ids, stop_token_id], dim=-1)\n\n    #     # Extend the attention mask to fit the new shape of input\n    #     # Note: Only batch size == 1 supported right now\n    #     mask_extension = (\n    #         torch.ones((attention_mask.shape[0], input_ids.shape[-1] - attention_mask.shape[-1]))\n    #         .to(attention_mask.device)\n    #         .to(attention_mask.dtype)\n    #     )\n    #     attention_mask = torch.cat([attention_mask, mask_extension], dim=-1)\n\n    #     #return input_ids, attention_mask\n        \n    #     #test end\n        \n\n    #     # Update labels tensor for action mask computation later\n    #     #labels = self._prepare_labels_for_action_prediction_verl(labels, input_ids)\n    #     #test \n        \n    #     ARBITRARY_ACTION_TOKEN_IDX = ACTION_TOKEN_BEGIN_IDX + 1\n    #     labels_extension = (\n    #         torch.ones((labels.shape[0], input_ids.shape[-1] - labels.shape[-1])).to(labels.device).to(labels.dtype)\n    #         * ARBITRARY_ACTION_TOKEN_IDX\n    #     )\n    #     labels = torch.cat([labels, labels_extension], dim=-1)\n\n    #     # Replace last label token with stop token\n    #     labels[:, -1] = STOP_INDEX\n\n    #     #return labels\n        \n    #     #test ed\n       \n\n    #     # Get input embeddings and action masks\n        \n        \n        \n    #     input_embeddings = self.get_input_embeddings()(input_ids)\n        \n        \n    #     #all_actions_mask = self._process_action_masks(labels)\n    #     #test\n    #     #current_action_mask = get_current_action_mask(labels)\n    #     newline_positions = labels != IGNORE_INDEX\n\n    #     # Calculate cumulative sum to identify regions between newlines\n    #     cumsum = torch.cumsum(newline_positions, dim=1)\n\n    #     # Create the mask\n    #     mask = (1 <= cumsum) & (cumsum <= 
ACTION_DIM)\n\n    #     # Extract the action part only\n    #     action_tokens_only_mask = labels > ACTION_TOKEN_BEGIN_IDX\n    #     current_action_mask = action_tokens_only_mask * mask\n\n    #     #next_actions_mask = get_next_actions_mask(labels)\n    #     newline_positions = labels != IGNORE_INDEX\n\n    #     # Calculate cumulative sum to identify regions between newlines\n    #     cumsum = torch.cumsum(newline_positions, dim=1)\n\n    #     # Create the mask\n    #     mask = cumsum > ACTION_DIM\n\n    #     # Extract the action part only\n    #     action_tokens_only_mask = labels > ACTION_TOKEN_BEGIN_IDX\n    #     next_actions_mask = action_tokens_only_mask * mask\n        \n    #     all_actions_mask = current_action_mask | next_actions_mask  # (B, seq_len)\n        \n    #     #test end\n        \n    #     # Extract language embeddings\n    #     language_embeddings = input_embeddings[~all_actions_mask].reshape(\n    #         input_embeddings.shape[0], -1, input_embeddings.shape[2]\n    #     )\n\n    #     # Process vision features\n    #     #projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film)\n    #     #test\n    #     if use_film:\n    #         # FiLM: Infuse language inputs into visual features\n    #         raise ValueError\n    #         patch_features = self.vision_backbone(pixel_values, language_embeddings)  # (bsz, 256 * num_images, D)\n    #     else:\n    #         patch_features = self.vision_backbone(pixel_values)  # (bsz, 256 * num_images, D)\n\n    #     projected_patch_embeddings = self.projector(patch_features)\n    #     #test end\n        \n        \n    #     # Add proprioceptive features if provided\n    #     use_proprio = proprio_projector is not None and proprio is not None\n    #     if use_proprio:\n    #         proprio = torch.Tensor(proprio).to(projected_patch_embeddings.device, dtype=projected_patch_embeddings.dtype)\n    #         projected_patch_embeddings = self._process_proprio_features(\n    #             projected_patch_embeddings, proprio, proprio_projector\n    #         )\n\n    #     # Use diffusion if provided, otherwise use regression or discrete prediction\n    #     use_diffusion = noisy_action_projector is not None and hasattr(action_head, \"noise_scheduler\")\n\n    #     # Calculate number of patches (including proprio token and/or diffusion timestep embedding if present)\n    #     NUM_PATCHES = self.vision_backbone.get_num_patches() * self.vision_backbone.get_num_images_in_input()\n    #     if use_proprio:\n    #         NUM_PATCHES += 1\n    #     if use_diffusion:\n    #         NUM_PATCHES += 1\n\n    #     if use_diffusion:\n    #         raise ValueError\n    #         # Sample random noise with shape equal to output action, used as the starting state for reverse diffusion\n    #         noise = torch.randn(\n    #             size=(1, NUM_ACTIONS_CHUNK, ACTION_DIM), device=input_embeddings.device, dtype=input_embeddings.dtype\n    #         )\n\n    #         # Run diffusion-based prediction\n    #         normalized_actions, actions_hidden_states = self._run_diffusion_prediction(\n    #             input_embeddings,\n    #             all_actions_mask,\n    #             noise,\n    #             action_head,\n    #             projected_patch_embeddings,\n    #             labels,\n    #             attention_mask,\n    #             NUM_PATCHES,\n    #             NUM_PROMPT_TOKENS,\n    #             noisy_action_projector,\n    #         )\n    #     
else:\n    #         # Run regression or discrete token-based prediction\n    #         # compute_logits = self._verl_discrete_compute_logits(\n    #         #     input_embeddings,\n    #         #     all_actions_mask,\n    #         #     projected_patch_embeddings,\n    #         #     attention_mask,\n    #         #     labels,\n    #         #     NUM_PATCHES,\n    #         #     NUM_PROMPT_TOKENS,\n    #         #     action_head,\n    #         # )\n            \n    #         #test\n            \n    #         all_actions_mask = all_actions_mask.unsqueeze(-1)  # (B, seq_len, 1)\n    #         input_embeddings = input_embeddings * ~all_actions_mask\n\n    #         # Build multimodal embeddings and attention mask\n    #         # multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(\n    #         #     input_embeddings, projected_patch_embeddings, attention_mask\n    #         # )\n    #         #test\n            \n    #         projected_patch_attention_mask = None\n    #         if attention_mask is not None:\n    #             projected_patch_attention_mask = torch.full(\n    #                 (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),\n    #                 fill_value=True,\n    #                 dtype=attention_mask.dtype,\n    #                 device=attention_mask.device,\n    #             )\n\n    #         # Build multimodal embeddings & attention mask; insert embeddings after <BOS> token (1:)\n    #         multimodal_embeddings = torch.cat(\n    #             [input_embeddings[:, :1, :], projected_patch_embeddings, input_embeddings[:, 1:, :]], dim=1\n    #         )\n\n    #         multimodal_attention_mask = None\n    #         if attention_mask is not None:\n    #             multimodal_attention_mask = torch.cat(\n    #                 [attention_mask[:, :1], projected_patch_attention_mask, attention_mask[:, 1:]], dim=1\n    #             )\n\n    #         #return multimodal_embeddings, multimodal_attention_mask\n            \n    #         #test end\n\n    #         # Forward pass through language model\n    #         language_model_output = self.language_model(\n    #             input_ids=None,\n    #             attention_mask=multimodal_attention_mask,\n    #             position_ids=None,\n    #             past_key_values=None,\n    #             inputs_embeds=multimodal_embeddings,\n    #             labels=None,\n    #             use_cache=None,\n    #             output_attentions=False,\n    #             output_hidden_states=False,\n    #             return_dict=True,\n    #         )\n\n        \n    #         compute_logits = language_model_output.logits[\n    #                     :,\n    #                     NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,\n    #                 ]\n                \n    #         #test end\n\n    #     return compute_logits\n    \n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        pixel_values=None,\n        attention_mask=None,\n        #labels=None,\n        proprio=None,\n        proprio_projector=None,\n        action_head=None,\n        noisy_action_projector=None,\n        use_film: bool = False,\n        **kwargs: str,\n    ) :\n        \"\"\"Predict actions from input sequence, with options for different prediction methods.\n\n        Args:\n            input_ids: Input token ids\n            unnorm_key: Key for unnormalization statistics\n     
       proprio: Proprioceptive features\n            proprio_projector: Projector for proprioceptive features\n            action_head: Optional head for L1 regression or diffusion-based prediction\n            noisy_action_projector: Projector for noisy actions in diffusion-based prediction\n            use_film: Whether to use FiLM conditioning\n            **kwargs: Additional arguments including pixel_values and attention_mask\n\n        Returns:\n            Tuple of (unnormalized_actions, action_hidden_states)\n        \"\"\"\n        # If the special empty token ('') does not already appear after the colon (':') token in the prompt\n        # (after \"OUT:\" or \"ASSISTANT:\"), insert it to match the inputs seen at training time\n        # if not torch.all(input_ids[:, -1] == 29871):\n        #     input_ids = torch.cat(\n        #         (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1\n        #     )\n        \n        #pixel_values = kwargs[\"pixel_values\"]\n        #attention_mask = kwargs[\"attention_mask\"]\n        \n        # Create fake labels tensor (needed for action mask)\n        labels = input_ids.clone()\n        labels[:] = IGNORE_INDEX\n\n        # # Get number of tokens in prompt (excluding the start token)\n        NUM_PROMPT_TOKENS = input_ids.shape[-1] - 1  # Subtract action tokens and stop token\n\n\n        # # Prepare inputs by adding necessary tokens\n        # #input_ids, attention_mask = self._prepare_input_for_action_prediction_verl(input_ids, attention_mask)\n        \n        # #test\n        placeholder_action_token_ids = (\n            torch.ones((input_ids.shape[0], ACTION_DIM * NUM_ACTIONS_CHUNK)).to(input_ids.device).to(input_ids.dtype)\n        )\n        input_ids = torch.cat([input_ids, placeholder_action_token_ids], dim=-1)\n\n        # Add stop token to sequence (needed in non-causal bi-directional self-attention, as it appears at train time)\n        stop_token_id = torch.ones((input_ids.shape[0], 1)).to(input_ids.device).to(input_ids.dtype) * STOP_INDEX\n        input_ids = torch.cat([input_ids, stop_token_id], dim=-1)\n\n        # Extend the attention mask to fit the new shape of input\n        # Note: Only batch size == 1 supported right now\n        mask_extension = (\n            torch.ones((attention_mask.shape[0], input_ids.shape[-1] - attention_mask.shape[-1]))\n            .to(attention_mask.device)\n            .to(attention_mask.dtype)\n        )\n        attention_mask = torch.cat([attention_mask, mask_extension], dim=-1)\n\n        ARBITRARY_ACTION_TOKEN_IDX = ACTION_TOKEN_BEGIN_IDX + 1\n        labels_extension = (\n            torch.ones((labels.shape[0], input_ids.shape[-1] - labels.shape[-1])).to(labels.device).to(labels.dtype)\n            * ARBITRARY_ACTION_TOKEN_IDX\n        )\n        labels = torch.cat([labels, labels_extension], dim=-1)\n\n        # # Replace last label token with stop token\n        labels[:, -1] = STOP_INDEX\n\n        \n        # Get input embeddings and action masks\n        \n        #NUM_PROMPT_TOKENS = kwargs[\"num_prompt_tokens\"]\n        \n        input_embeddings = self.get_input_embeddings()(input_ids)\n        \n        \n        #all_actions_mask = self._process_action_masks(labels)\n        #test\n        #current_action_mask = get_current_action_mask(labels)\n        newline_positions = labels != IGNORE_INDEX\n\n        # Calculate cumulative sum to identify regions between newlines\n        cumsum = torch.cumsum(newline_positions, 
dim=1)\n\n        # Create the mask\n        mask = (1 <= cumsum) & (cumsum <= ACTION_DIM)\n\n        # Extract the action part only\n        action_tokens_only_mask = labels > ACTION_TOKEN_BEGIN_IDX\n        current_action_mask = action_tokens_only_mask * mask\n\n        #next_actions_mask = get_next_actions_mask(labels)\n        newline_positions = labels != IGNORE_INDEX\n\n        # Calculate cumulative sum to identify regions between newlines\n        cumsum = torch.cumsum(newline_positions, dim=1)\n\n        # Create the mask\n        mask = cumsum > ACTION_DIM\n\n        # Extract the action part only\n        action_tokens_only_mask = labels > ACTION_TOKEN_BEGIN_IDX\n        next_actions_mask = action_tokens_only_mask * mask\n        \n        all_actions_mask = current_action_mask | next_actions_mask  # (B, seq_len)\n        \n        #test end\n        \n        # Extract language embeddings\n        language_embeddings = input_embeddings[~all_actions_mask].reshape(\n            input_embeddings.shape[0], -1, input_embeddings.shape[2]\n        )\n\n        # Process vision features\n        #projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film)\n        #test\n        if use_film:\n            # FiLM: Infuse language inputs into visual features\n            raise ValueError(\"FiLM conditioning is not supported in this forward pass\")\n            patch_features = self.vision_backbone(pixel_values, language_embeddings)  # (bsz, 256 * num_images, D)\n        else:\n            patch_features = self.vision_backbone(pixel_values)  # (bsz, 256 * num_images, D)\n\n        projected_patch_embeddings = self.projector(patch_features)\n        #test end\n        \n        \n        # Add proprioceptive features if provided\n        use_proprio = proprio_projector is not None and proprio is not None\n        if use_proprio:\n            proprio = torch.Tensor(proprio).to(projected_patch_embeddings.device, dtype=projected_patch_embeddings.dtype)\n            projected_patch_embeddings = self._process_proprio_features(\n                projected_patch_embeddings, proprio, proprio_projector\n            )\n\n        # Use diffusion if provided, otherwise use regression or discrete prediction\n        use_diffusion = noisy_action_projector is not None and hasattr(action_head, \"noise_scheduler\")\n\n        # Calculate number of patches (including proprio token and/or diffusion timestep embedding if present)\n        NUM_PATCHES = self.vision_backbone.get_num_patches() * self.vision_backbone.get_num_images_in_input()\n        if use_proprio:\n            NUM_PATCHES += 1\n        if use_diffusion:\n            NUM_PATCHES += 1\n\n        if use_diffusion:\n            raise ValueError(\"Diffusion-based action prediction is not supported in this forward pass\")\n            # Sample random noise with shape equal to output action, used as the starting state for reverse diffusion\n            noise = torch.randn(\n                size=(1, NUM_ACTIONS_CHUNK, ACTION_DIM), device=input_embeddings.device, dtype=input_embeddings.dtype\n            )\n\n            # Run diffusion-based prediction\n            normalized_actions, actions_hidden_states = self._run_diffusion_prediction(\n                input_embeddings,\n                all_actions_mask,\n                noise,\n                action_head,\n                projected_patch_embeddings,\n                labels,\n                attention_mask,\n                NUM_PATCHES,\n                NUM_PROMPT_TOKENS,\n                noisy_action_projector,\n            )\n        else:\n            # Run regression or discrete 
token-based prediction\n            # compute_logits = self._verl_discrete_compute_logits(\n            #     input_embeddings,\n            #     all_actions_mask,\n            #     projected_patch_embeddings,\n            #     attention_mask,\n            #     labels,\n            #     NUM_PATCHES,\n            #     NUM_PROMPT_TOKENS,\n            #     action_head,\n            # )\n            \n            #test\n            \n            all_actions_mask = all_actions_mask.unsqueeze(-1)  # (B, seq_len, 1)\n            input_embeddings = input_embeddings * ~all_actions_mask\n\n            # Build multimodal embeddings and attention mask\n            # multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(\n            #     input_embeddings, projected_patch_embeddings, attention_mask\n            # )\n            #test\n            \n            projected_patch_attention_mask = None\n            if attention_mask is not None:\n                projected_patch_attention_mask = torch.full(\n                    (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),\n                    fill_value=True,\n                    dtype=attention_mask.dtype,\n                    device=attention_mask.device,\n                )\n\n            # Build multimodal embeddings & attention mask; insert embeddings after <BOS> token (1:)\n            multimodal_embeddings = torch.cat(\n                [input_embeddings[:, :1, :], projected_patch_embeddings, input_embeddings[:, 1:, :]], dim=1\n            )\n\n            multimodal_attention_mask = None\n            if attention_mask is not None:\n                multimodal_attention_mask = torch.cat(\n                    [attention_mask[:, :1], projected_patch_attention_mask, attention_mask[:, 1:]], dim=1\n                )\n\n            #return multimodal_embeddings, multimodal_attention_mask\n            \n            #test end\n\n            # Forward pass through language model\n            language_model_output = self.language_model(\n                input_ids=None,\n                attention_mask=multimodal_attention_mask,\n                position_ids=None,\n                past_key_values=None,\n                inputs_embeds=multimodal_embeddings,\n                labels=None,\n                use_cache=None,\n                output_attentions=False,\n                output_hidden_states=False,\n                return_dict=True,\n            )\n\n        \n            compute_logits = language_model_output.logits[\n                        :,\n                        NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,\n                    ]\n                \n            #test end\n\n        return compute_logits\n    \n    \n  \nclass OpenVLAForActionPrediction(PrismaticForConditionalGeneration):\n    config_class: PretrainedConfig = OpenVLAConfig\n\n    def __init__(self, config: OpenVLAConfig) -> None:\n        super().__init__(config)\n        self.norm_stats = config.norm_stats\n\n        # Compute action bins\n        self.bins = np.linspace(-1, 1, config.n_action_bins)\n        self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0\n\n        # Compute vocab size for de-tokenization -- revert added \"multiple of\"\n        self.vocab_size = self.config.text_config.vocab_size - self.config.pad_to_multiple_of\n\n    def _prepare_input_for_action_prediction(self, input_ids, attention_mask):\n        \"\"\"Prepares input for action 
prediction by adding necessary tokens\"\"\"\n        # Add (ACTION_DIM * NUM_ACTIONS_CHUNK) placeholder tokens to input_ids to simulate action tokens\n        placeholder_action_token_ids = (\n            torch.ones((input_ids.shape[0], ACTION_DIM * NUM_ACTIONS_CHUNK)).to(input_ids.device).to(input_ids.dtype)\n        )\n        input_ids = torch.cat([input_ids, placeholder_action_token_ids], dim=-1)\n\n        # Add stop token to sequence (needed in non-causal bi-directional self-attention, as it appears at train time)\n        stop_token_id = torch.ones((input_ids.shape[0], 1)).to(input_ids.device).to(input_ids.dtype) * STOP_INDEX\n        input_ids = torch.cat([input_ids, stop_token_id], dim=-1)\n\n        # Extend the attention mask to fit the new shape of input\n        # Note: Only batch size == 1 supported right now\n        mask_extension = (\n            torch.ones((attention_mask.shape[0], input_ids.shape[-1] - attention_mask.shape[-1]))\n            .to(attention_mask.device)\n            .to(attention_mask.dtype)\n        )\n        attention_mask = torch.cat([attention_mask, mask_extension], dim=-1)\n\n        return input_ids, attention_mask\n\n    def _prepare_labels_for_action_prediction(self, labels, input_ids):\n        \"\"\"Creates labels tensor for action prediction if not provided\"\"\"\n        # Extend labels tensor with fake action labels\n        ARBITRARY_ACTION_TOKEN_IDX = ACTION_TOKEN_BEGIN_IDX + 1\n        labels_extension = (\n            torch.ones((labels.shape[0], input_ids.shape[-1] - labels.shape[-1])).to(labels.device).to(labels.dtype)\n            * ARBITRARY_ACTION_TOKEN_IDX\n        )\n        labels = torch.cat([labels, labels_extension], dim=-1)\n\n        # Replace last label token with stop token\n        labels[:, -1] = STOP_INDEX\n\n        return labels\n\n    def _unnormalize_actions(self, normalized_actions, unnorm_key=None):\n        \"\"\"Unnormalize actions using dataset statistics\"\"\"\n        action_norm_stats = self.get_action_stats(unnorm_key)\n\n        if ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS:\n            mask = action_norm_stats.get(\"mask\", np.ones_like(action_norm_stats[\"min\"], dtype=bool))\n            action_high, action_low = np.array(action_norm_stats[\"max\"]), np.array(action_norm_stats[\"min\"])\n        elif ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS_Q99:\n            mask = action_norm_stats.get(\"mask\", np.ones_like(action_norm_stats[\"q01\"], dtype=bool))\n            action_high, action_low = np.array(action_norm_stats[\"q99\"]), np.array(action_norm_stats[\"q01\"])\n        else:\n            raise ValueError(\"Unsupported action/proprio normalization type detected!\")\n\n        actions = np.where(\n            mask,\n            0.5 * (normalized_actions + 1) * (action_high - action_low + 1e-8) + action_low,\n            normalized_actions,\n        )\n\n        return actions\n\n    def _run_diffusion_prediction(\n        self,\n        input_embeddings,\n        all_actions_mask,\n        noise,\n        action_head,\n        projected_patch_embeddings,\n        labels,\n        attention_mask,\n        NUM_PATCHES,\n        NUM_PROMPT_TOKENS,\n        noisy_action_projector,\n    ):\n        \"\"\"Run diffusion-based action prediction\"\"\"\n        # Set diffusion timestep values\n        action_head.noise_scheduler.set_timesteps(action_head.num_diffusion_steps)\n        # Clone embedding for reuse in each timestep\n        orig_projected_patch_embeddings = 
projected_patch_embeddings.clone()\n        curr_noisy_actions = noise\n\n        # Reverse diffusion: Iteratively denoise to generate action prediction\n        for t in action_head.noise_scheduler.timesteps:\n            # Get diffusion model's noise prediction (conditioned on VLA latent embedding, current noisy action\n            # embedding, and diffusion timestep embedding)\n            timesteps = torch.Tensor([t]).to(labels.device)\n            diffusion_timestep_embeddings = (\n                action_head.time_encoder(timesteps).to(curr_noisy_actions.dtype).to(curr_noisy_actions.device)\n            )  # (B, llm_dim)\n            diffusion_timestep_embeddings = diffusion_timestep_embeddings.unsqueeze(1)  # (B, 1, llm_dim)\n\n            # [Diffusion] Replace the embeddings of the action tokens with noisy actions\n            # (Later on, the positional embeddings will be added to them)\n\n            # For simplicity, append diffusion timestep embedding to the end of projected vision tokens\n            projected_patch_embeddings = torch.cat(\n                (orig_projected_patch_embeddings, diffusion_timestep_embeddings), dim=1\n            )\n\n            # Reshape and project noisy actions into language embedding space\n            B = curr_noisy_actions.shape[0]\n            orig_curr_noisy_actions_shape = curr_noisy_actions.shape\n            curr_noisy_actions = curr_noisy_actions.reshape(B, -1).unsqueeze(-1)\n            noisy_action_features = noisy_action_projector(curr_noisy_actions)\n            curr_noisy_actions = curr_noisy_actions.reshape(orig_curr_noisy_actions_shape)\n\n            # Replace action token embeddings with noisy action embeddings\n            input_embeddings = self._replace_input_embeddings(\n                input_embeddings.clone(), all_actions_mask, noisy_action_features\n            )\n\n            # Build multimodal embeddings and attention mask\n            multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(\n                input_embeddings, projected_patch_embeddings, attention_mask\n            )\n\n            # Forward pass through language model\n            language_model_output = self.language_model(\n                input_ids=None,\n                attention_mask=multimodal_attention_mask,\n                position_ids=None,\n                past_key_values=None,\n                inputs_embeds=multimodal_embeddings,\n                labels=None,\n                use_cache=None,\n                output_attentions=False,\n                output_hidden_states=True,\n                return_dict=True,\n            )\n\n            # Extract hidden states for action portion of response\n            last_hidden_states = language_model_output.hidden_states[-1]  # (B, seq_len, D)\n            actions_hidden_states = last_hidden_states[\n                :,\n                NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,\n                :,\n            ]  # (B, act_chunk_len, D)\n\n            # Predict noise and update noisy actions: x_t -> x_{t-1}\n            noise_pred = action_head.predict_noise(actions_hidden_states)\n            curr_noisy_actions = action_head.noise_scheduler.step(noise_pred, t, curr_noisy_actions).prev_sample\n\n        curr_noisy_actions = curr_noisy_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)\n\n        # Return final actions\n        return curr_noisy_actions.float().cpu().detach().numpy(), actions_hidden_states\n\n    def 
_regression_or_discrete_prediction(\n        self,\n        input_embeddings,\n        all_actions_mask,\n        projected_patch_embeddings,\n        attention_mask,\n        labels,\n        NUM_PATCHES,\n        NUM_PROMPT_TOKENS,\n        action_head=None,\n    ):\n        \"\"\"Run L1 regression-based continuous action prediction or discrete action tokens prediction.\"\"\"\n        # Zero out action token embeddings\n        all_actions_mask = all_actions_mask.unsqueeze(-1)  # (B, seq_len, 1)\n        input_embeddings = input_embeddings * ~all_actions_mask\n\n        # Build multimodal embeddings and attention mask\n        multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(\n            input_embeddings, projected_patch_embeddings, attention_mask\n        )\n\n        # Forward pass through language model\n        language_model_output = self.language_model(\n            input_ids=None,\n            attention_mask=multimodal_attention_mask,\n            position_ids=None,\n            past_key_values=None,\n            inputs_embeds=multimodal_embeddings,\n            labels=None,\n            use_cache=None,\n            output_attentions=False,\n            output_hidden_states=True,\n            return_dict=True,\n        )\n\n        # Extract hidden states for action tokens\n        last_hidden_states = language_model_output.hidden_states[-1]  # (B, seq_len, D)\n        actions_hidden_states = last_hidden_states[\n            :,\n            NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,\n            :,\n        ]  # (B, act_chunk_len, D)\n\n        # Handle different prediction methods\n        if action_head is not None:\n            # L1 regression prediction\n            normalized_actions = action_head.predict_action(actions_hidden_states)\n            normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)\n            normalized_actions = normalized_actions.float().cpu().detach().numpy()\n        else:\n            # Discrete token-based prediction\n            predicted_action_token_ids = (\n                language_model_output.logits[\n                    :,\n                    NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,\n                ]\n                .argmax(dim=2)\n                .cpu()\n                .numpy()\n            )\n            discretized_actions = self.vocab_size - predicted_action_token_ids\n            discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)\n            normalized_actions = self.bin_centers[discretized_actions]\n            normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)\n\n        return normalized_actions, actions_hidden_states\n    \n    def _verl_discrete_prediction(\n        self,\n        input_embeddings,\n        all_actions_mask,\n        projected_patch_embeddings,\n        attention_mask,\n        labels,\n        NUM_PATCHES,\n        NUM_PROMPT_TOKENS,\n        action_head=None,\n        do_sample=True,\n        temperature=1,\n    ):\n        \"\"\"Run L1 regression-based continuous action prediction or discrete action tokens prediction.\"\"\"\n        # Zero out action token embeddings\n        all_actions_mask = all_actions_mask.unsqueeze(-1)  # (B, seq_len, 1)\n        input_embeddings = input_embeddings * ~all_actions_mask\n\n        # Build multimodal embeddings and 
attention mask\n        multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(\n            input_embeddings, projected_patch_embeddings, attention_mask\n        )\n\n        # Forward pass through language model\n        language_model_output = self.language_model(\n            input_ids=None,\n            attention_mask=multimodal_attention_mask,\n            position_ids=None,\n            past_key_values=None,\n            inputs_embeds=multimodal_embeddings,\n            labels=None,\n            use_cache=None,\n            output_attentions=False,\n            output_hidden_states=False,\n            return_dict=True,\n        )\n\n        # Extract hidden states for action tokens\n        #last_hidden_states = language_model_output.hidden_states[-1]  # (B, seq_len, D)\n        # actions_hidden_states = last_hidden_states[\n        #     :,\n        #     NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,\n        #     :,\n        # ]  # (B, act_chunk_len, D)\n\n        # Handle different prediction methods\n        # if action_head is not None:\n        #     # L1 regression prediction\n        #     normalized_actions = action_head.predict_action(actions_hidden_states)\n        #     normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)\n        #     normalized_actions = normalized_actions.float().cpu().detach().numpy()\n        # else:\n        # Discrete token-based prediction\n        \n        #test \n        # NUM_PROMPT_TOKENS = NUM_PROMPT_TOKENS + NUM_PATCHES\n        # j = torch.arange(language_model_output.logits.shape[1], device=NUM_PROMPT_TOKENS.device)\n        # start = NUM_PROMPT_TOKENS.unsqueeze(1)\n        # end = start + ACTION_DIM * NUM_ACTIONS_CHUNK\n        # mask_2d = (j >= start) & (j < end)\n        # mask = mask_2d.unsqueeze(-1) \n        # actions_masks = mask.expand_as(language_model_output.logits)  \n        \n        \n        NUM_PROMPT_TOKENS = NUM_PROMPT_TOKENS + NUM_PATCHES\n        batch_size = language_model_output.logits.shape[0]\n        device = language_model_output.logits.device\n\n        start_indices = NUM_PROMPT_TOKENS.unsqueeze(1)  # [batch_size, 1]\n        position_offsets = torch.arange(ACTION_DIM * NUM_ACTIONS_CHUNK, device=device).unsqueeze(0)  # [1, ACTION_DIM * NUM_ACTIONS_CHUNK]\n        seq_indices = start_indices + position_offsets  # [batch_size, ACTION_DIM*NUM_ACTIONS_CHUNK]\n        #test end\n        #test add\n        #print(\"language_model_output\",language_model_output.logits.shape[-1])\n        #print(\"self.vocab_size\",self.vocab_size) 32000\n        #topk_values, topk_indices = torch.topk(language_model_output.logits, k=256, dim=-1)\n        #print(topk_indices)\n        #assert language_model_output.logits.shape[-1] == self.vocab_size\n        #test add\n        if not do_sample:\n            #org\n            # reponse_ids = language_model_output.logits[\n            #         :,\n            #         NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,\n            #     ].argmax(dim=2)\n            #reponse_ids = language_model_output.logits[actions_masks].argmax(dim=2)\n            #org end\n            \n            #padding\n            # reponse_ids = language_model_output.logits[\n            #     torch.arange(batch_size, device=device).unsqueeze(-1),  \n            #     seq_indices, \n            #     :\n            # ].argmax(dim=2)  \n            #padding end\n            \n            #padding + only get last 256 token\n            response_ids_logits = language_model_output.logits[\n                torch.arange(batch_size, device=device).unsqueeze(-1),\n                seq_indices,\n                :\n            ]\n            start_index = self.vocab_size - 256  # first action-token id (action tokens occupy the last 256 ids of the de-padded vocab)\n            response_last256 = response_ids_logits[..., -256-64:-64]  # keep the 256 action-token logits; the final 64 columns are pad_to_multiple_of padding\n            last256_argmax = response_last256.argmax(dim=-1)  # Shape: [batch_size, ACTION_DIM * NUM_ACTIONS_CHUNK]\n            response_ids = last256_argmax + start_index  # map local argmax indices back to vocabulary token ids\n            #padding + only get last 256 token end\n            \n            predicted_action_token_ids = response_ids.cpu().numpy()\n\n        else:\n            assert temperature > 0, \"temperature must be positive when do_sample=True\"\n            #org \n            # action_logits  = language_model_output.logits[\n            #         :,\n            #         NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,\n            #     ]\n            #action_logits = language_model_output.logits[actions_masks]\n            #org end\n            \n            action_logits = language_model_output.logits[\n                torch.arange(batch_size, device=device).unsqueeze(-1),\n                seq_indices,\n                :\n            ]\n            # padding \n            # scaled_logits = action_logits / temperature\n            # probs = torch.softmax(scaled_logits, dim=-1)\n            # probs_flat = probs.reshape(-1, probs.shape[-1])  # (B*act_chunk_len, vocab_size)\n            # sampled_indices_flat = torch.multinomial(probs_flat, num_samples=1)  # (B*act_chunk_len, 1)\n            # reponse_ids = sampled_indices_flat.view(action_logits.shape[0], -1)\n            # padding end \n            \n            #padding + only get last 256 token\n            action_logits_last256 = action_logits[..., -256-64:-64]\n            scaled_logits = action_logits_last256 / temperature\n            probs = torch.softmax(scaled_logits, dim=-1)\n            assert probs.shape[-1] == 256\n            probs_flat = probs.reshape(-1, probs.shape[-1])\n            sampled_indices_flat = torch.multinomial(probs_flat, num_samples=1)\n            original_ids_flat = sampled_indices_flat + (self.vocab_size - 256)\n            response_ids = original_ids_flat.view(action_logits.shape[0], -1)\n            #padding + only get last 256 token end\n            \n            predicted_action_token_ids = response_ids.cpu().numpy()\n\n        discretized_actions = self.vocab_size - predicted_action_token_ids\n        discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)\n        normalized_actions = self.bin_centers[discretized_actions]\n        #normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)\n        normalized_actions = normalized_actions.reshape(-1, ACTION_DIM)\n\n        return normalized_actions, response_ids\n        #return normalized_actions, actions_hidden_states\n\n    def predict_action(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        unnorm_key: Optional[str] = None,\n        proprio=None,\n        proprio_projector=None,\n        action_head=None,\n        noisy_action_projector=None,\n        use_film: bool = False,\n        **kwargs: str,\n    ) -> np.ndarray:\n        \"\"\"Predict actions from input sequence, with options for different prediction methods.\n\n        
Args:\n            input_ids: Input token ids\n            unnorm_key: Key for unnormalization statistics\n            proprio: Proprioceptive features\n            proprio_projector: Projector for proprioceptive features\n            action_head: Optional head for L1 regression or diffusion-based prediction\n            noisy_action_projector: Projector for noisy actions in diffusion-based prediction\n            use_film: Whether to use FiLM conditioning\n            **kwargs: Additional arguments including pixel_values and attention_mask\n\n        Returns:\n            Tuple of (unnormalized_actions, action_hidden_states)\n        \"\"\"\n        # If the special empty token ('') does not already appear after the colon (':') token in the prompt\n        # (after \"OUT:\" or \"ASSISTANT:\"), insert it to match the inputs seen at training time\n        if not torch.all(input_ids[:, -1] == 29871):\n            input_ids = torch.cat(\n                (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1\n            )\n\n        pixel_values = kwargs[\"pixel_values\"]\n        attention_mask = kwargs[\"attention_mask\"]\n\n        # Create fake labels tensor (needed for action mask)\n        labels = input_ids.clone()\n        labels[:] = IGNORE_INDEX\n\n        # Get number of tokens in prompt (excluding the start token)\n        NUM_PROMPT_TOKENS = input_ids.shape[-1] - 1  # Subtract action tokens and stop token\n\n        # Prepare inputs by adding necessary tokens\n        input_ids, attention_mask = self._prepare_input_for_action_prediction(input_ids, attention_mask)\n\n        # Update labels tensor for action mask computation later\n        labels = self._prepare_labels_for_action_prediction(labels, input_ids)\n\n        # Get input embeddings and action masks\n        input_embeddings = self.get_input_embeddings()(input_ids)\n        all_actions_mask = self._process_action_masks(labels)\n\n        # Extract language embeddings\n        language_embeddings = input_embeddings[~all_actions_mask].reshape(\n            input_embeddings.shape[0], -1, input_embeddings.shape[2]\n        )\n\n        # Process vision features\n        projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film)\n\n        # Add proprioceptive features if provided\n        use_proprio = proprio_projector is not None and proprio is not None\n        if use_proprio:\n            proprio = torch.Tensor(proprio).to(projected_patch_embeddings.device, dtype=projected_patch_embeddings.dtype)\n            projected_patch_embeddings = self._process_proprio_features(\n                projected_patch_embeddings, proprio, proprio_projector\n            )\n\n        # Use diffusion if provided, otherwise use regression or discrete prediction\n        use_diffusion = noisy_action_projector is not None and hasattr(action_head, \"noise_scheduler\")\n\n        # Calculate number of patches (including proprio token and/or diffusion timestep embedding if present)\n        NUM_PATCHES = self.vision_backbone.get_num_patches() * self.vision_backbone.get_num_images_in_input()\n        if use_proprio:\n            NUM_PATCHES += 1\n        if use_diffusion:\n            NUM_PATCHES += 1\n\n        if use_diffusion:\n            # Sample random noise with shape equal to output action, used as the starting state for reverse diffusion\n            noise = torch.randn(\n                size=(1, NUM_ACTIONS_CHUNK, ACTION_DIM), 
device=input_embeddings.device, dtype=input_embeddings.dtype\n            )\n\n            # Run diffusion-based prediction\n            normalized_actions, actions_hidden_states = self._run_diffusion_prediction(\n                input_embeddings,\n                all_actions_mask,\n                noise,\n                action_head,\n                projected_patch_embeddings,\n                labels,\n                attention_mask,\n                NUM_PATCHES,\n                NUM_PROMPT_TOKENS,\n                noisy_action_projector,\n            )\n        else:\n            # Run regression or discrete token-based prediction\n            normalized_actions, actions_hidden_states = self._regression_or_discrete_prediction(\n                input_embeddings,\n                all_actions_mask,\n                projected_patch_embeddings,\n                attention_mask,\n                labels,\n                NUM_PATCHES,\n                NUM_PROMPT_TOKENS,\n                action_head,\n            )\n\n        # Unnormalize predicted actions\n        actions = self._unnormalize_actions(normalized_actions, unnorm_key)\n\n        return actions, actions_hidden_states\n\n
    def generate_action_verl(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        unnorm_key: Optional[str] = None,\n        proprio=None,\n        proprio_projector=None,\n        action_head=None,\n        noisy_action_projector=None,\n        use_film: bool = False,\n        **kwargs: str,\n    ) -> np.ndarray:\n        \"\"\"Predict actions for a (possibly left-padded) batch of input sequences and return the sampled action token ids.\n\n        Args:\n            input_ids: Input token ids\n            unnorm_key: Key for unnormalization statistics\n            proprio: Proprioceptive features\n            proprio_projector: Projector for proprioceptive features\n            action_head: Optional head for L1 regression or diffusion-based prediction\n            noisy_action_projector: Projector for noisy actions in diffusion-based prediction\n            use_film: Whether to use FiLM conditioning\n            **kwargs: Additional arguments including pixel_values, attention_mask, padding_idx, do_sample, and temperature\n\n        Returns:\n            Tuple of (unnormalized actions with shape [batch, NUM_ACTIONS_CHUNK, ACTION_DIM], sampled action token ids)\n        \"\"\"\n        # NOTE: unlike `predict_action`, the special empty token ('') is not appended after the colon token here;\n        # inputs are expected to already match the training-time format.\n\n
        pixel_values = kwargs[\"pixel_values\"]\n        attention_mask = kwargs[\"attention_mask\"]\n        do_sample = kwargs[\"do_sample\"]\n        temperature = kwargs[\"temperature\"]\n        padding_idx = kwargs[\"padding_idx\"]\n\n        # Create fake labels tensor (needed for action mask)\n        labels = input_ids.clone()\n        labels[:] = IGNORE_INDEX\n\n        # Per-sample prompt length (excluding the start token); inputs may be padded, so count non-padding tokens per row\n        num_prompt_tokens = input_ids.ne(padding_idx).sum(dim=1) - 1\n\n        # Prepare inputs by adding necessary tokens\n        input_ids, attention_mask = self._prepare_input_for_action_prediction(input_ids, attention_mask)\n\n        # Update labels tensor for action mask computation later\n        labels = self._prepare_labels_for_action_prediction(labels, input_ids)\n\n
        # Convert left padding to right padding so that each prompt forms a contiguous prefix\n        padding_mask = input_ids.ne(padding_idx)\n        assert torch.all(padding_mask == attention_mask.ne(0))\n        sorted_indices = torch.argsort(padding_mask.int(), dim=1, descending=True, stable=True)\n        input_ids = torch.gather(input_ids, 1, sorted_indices)\n        attention_mask = torch.gather(attention_mask, 1, sorted_indices)\n        labels = torch.gather(labels, 1, sorted_indices)\n        assert not use_film\n\n
        # Get input embeddings and action masks\n        input_embeddings = self.get_input_embeddings()(input_ids)\n        all_actions_mask = self._process_action_masks(labels)\n\n        # Extract language embeddings\n        language_embeddings = input_embeddings[~all_actions_mask].reshape(\n            input_embeddings.shape[0], -1, input_embeddings.shape[2]\n        )\n\n        # Process vision features\n        projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film)\n\n        # Add proprioceptive features if provided\n        use_proprio = proprio_projector is not None and proprio is not None\n        if use_proprio:\n            proprio = torch.Tensor(proprio).to(projected_patch_embeddings.device, dtype=projected_patch_embeddings.dtype)\n            projected_patch_embeddings = self._process_proprio_features(\n                projected_patch_embeddings, proprio, proprio_projector\n            )\n\n
        # Use diffusion if provided, otherwise use regression or discrete prediction\n        use_diffusion = noisy_action_projector is not None and hasattr(action_head, \"noise_scheduler\")\n\n        # Calculate number of patches (including proprio token and/or diffusion timestep embedding if present)\n        NUM_PATCHES = self.vision_backbone.get_num_patches() * self.vision_backbone.get_num_images_in_input()\n        if use_proprio:\n            NUM_PATCHES += 1\n        if use_diffusion:\n            NUM_PATCHES += 1\n\n        if use_diffusion:\n            raise ValueError(\"Diffusion-based prediction is not supported in `generate_action_verl`\")\n        else:\n            # Run discrete token-based prediction\n            normalized_actions, response_ids = self._verl_discrete_prediction(\n                input_embeddings,\n                all_actions_mask,\n                projected_patch_embeddings,\n                attention_mask,\n                labels,\n                NUM_PATCHES,\n                num_prompt_tokens,\n                action_head,\n                do_sample=do_sample,\n                temperature=temperature,\n            )\n\n        # Unnormalize predicted actions\n        actions = self._unnormalize_actions(normalized_actions, unnorm_key)\n        actions = actions.reshape(-1, NUM_ACTIONS_CHUNK, ACTION_DIM)\n\n        return actions, response_ids\n\n
    @staticmethod\n    def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str:\n        \"\"\"Validate and resolve the unnormalization key for action statistics\"\"\"\n        if unnorm_key is None:\n            assert len(norm_stats) == 1, (\n                f\"Your model was trained on more than one dataset, \"\n                f\"please pass a `unnorm_key` from the following options to choose the statistics \"\n                f\"used for un-normalizing actions: {norm_stats.keys()}\"\n            )\n            unnorm_key = next(iter(norm_stats.keys()))\n\n        assert unnorm_key in norm_stats, (\n            f\"The `unnorm_key` you chose is not in the set of available dataset statistics, \"\n            f\"please choose from: {norm_stats.keys()}\"\n        )\n        return unnorm_key\n\n    def get_action_dim(self, unnorm_key: Optional[str] = None) -> int:\n        \"\"\"Get the dimensionality of the policy's action space.\"\"\"\n        unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)\n        return len(self.norm_stats[unnorm_key][\"action\"][\"min\"])\n\n    def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict[str, Any]:\n        \"\"\"Get all the logged statistics for the given dataset.\"\"\"\n        unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)\n        return self.norm_stats[unnorm_key][\"action\"]\n"
  },
  {
    "path": "siirl/models/embodied/openvla_oft/processing_prismatic.py",
    "content": "\"\"\"\nprocessing_prismatic.py\n\nHuggingFace-style preprocessor definitions for Prismatic VLMs, inheriting from `ProcessorMixin`. Default configuration\nspecifies `siglip-224px+7b`.\n\"\"\"\n\nfrom typing import Any, ClassVar, List, Optional, Tuple, Union\n\nimport timm.data\nimport torch\nimport torchvision.transforms.functional as TVF\nfrom PIL import Image\nfrom torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor\nfrom transformers import PreTrainedTokenizerBase\nfrom transformers.image_processing_utils import BatchFeature, ImageProcessingMixin\nfrom transformers.processing_utils import ProcessorMixin\nfrom transformers.tokenization_utils import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy\nfrom transformers.utils import TensorType\n\n\n# === Image Processing ===\ndef letterbox_pad_transform(image: Image.Image, padding_fill_value: Tuple[int, int, int]) -> Image.Image:\n    \"\"\"Given a PIL.Image, pad to square by adding a symmetric border around the height/width.\"\"\"\n    (w, h), max_wh = image.size, max(image.size)\n    horizontal_pad, vertical_pad = int((max_wh - w) / 2), int((max_wh - h) / 2)\n    padding = (horizontal_pad, vertical_pad, horizontal_pad, vertical_pad)\n\n    return TVF.pad(image, padding, fill=padding_fill_value, padding_mode=\"constant\")\n\n\nclass PrismaticImageProcessor(ImageProcessingMixin):\n    model_input_names: ClassVar[List[str]] = [\"pixel_values\"]\n\n    def __init__(\n        self,\n        use_fused_vision_backbone: bool = False,\n        image_resize_strategy: str = \"letterbox\",\n        input_sizes: Optional[List[Tuple[int, int, int]]] = None,\n        interpolations: Optional[List[str]] = None,\n        means: Optional[List[Tuple[float, float, float]]] = None,\n        stds: Optional[List[Tuple[float, float, float]]] = None,\n        **kwargs: str,\n    ) -> None:\n        \"\"\"\n        Initialize a PrismaticImageProcessor as a wrapper around a torchvision transform; this transform will be\n        created by TIMM, and edited to follow our custom `image_resize_strategy` logic.\n        @param use_fused_vision_backbone: Boolean indicating single or fused (dual) vision backbone\n        @param image_resize_strategy: Prismatic image resize strategy in < resize-naive | resize-crop | letterbox >\n        @param input_size: [TIMM :: `data_cfg`] Input image size as tuple (channels, width, height)\n        @param interpolation: [TIMM :: `data_cfg`] Interpolation as string (default: \"bicubic\")\n        @param mean: [TIMM :: `data_cfg`] Normalization mean as float tuple (or two-tuple if `fused_backbone`)\n        @param std: [TIMM :: `data_cfg`] Normalization std as float tuple (or two-tuple if `fused_backbone`)\n        \"\"\"\n        self.use_fused_vision_backbone = use_fused_vision_backbone\n        self.image_resize_strategy = image_resize_strategy\n\n        # Handle `None` default values\n        input_sizes = [(3, 224, 224)] if input_sizes is None else input_sizes\n        means = [(0.5, 0.5, 0.5)] if means is None else means\n        stds = [(0.5, 0.5, 0.5)] if stds is None else stds\n\n        # TIMM `data_cfg` Parameters\n        self.input_sizes, self.interpolations, self.means, self.stds = input_sizes, interpolations, means, stds\n\n        # Grab torchvision transforms via TIMM =>> need to parse for specific \"functional\" transform values!\n        self.tvf_resize_params, self.tvf_crop_params, self.tvf_normalize_params = [], [], []\n        self.tvf_do_letterbox, 
self.tvf_letterbox_fill = False, None\n\n        for idx in range(len(input_sizes)):\n            transform = timm.data.create_transform(\n                input_size=self.input_sizes[idx],\n                interpolation=self.interpolations[idx],\n                mean=self.means[idx],\n                std=self.stds[idx],\n                crop_pct=1.0,  # Set to 1.0 to ignore cropping (initial Resize sets `input_size`)\n                crop_mode=\"center\",  # Default crop mode -- no-op when `crop_pct == 1.0`\n                is_training=False,  # No image augmentations when loading the transform!\n            )\n\n            # [Validation] Ensure appropriate transform structure, expected sizes\n            if not (\n                isinstance(transform, Compose)\n                and (len(transform.transforms) == 4)\n                and isinstance(transform.transforms[0], Resize)\n                and isinstance(transform.transforms[1], CenterCrop)\n                and isinstance(transform.transforms[2], ToTensor)\n                and isinstance(transform.transforms[3], Normalize)\n                and (transform.transforms[0].size == self.input_sizes[idx][-1])\n                and (transform.transforms[1].size == self.input_sizes[idx][-2:])\n            ):\n                raise ValueError(f\"Unexpected TIMM image transformation structure/sizes: `{transform}`\")\n\n            # HF Image Processors *must* be JSON-serializable; as such, cannot have torchvision. as an attribute.\n            #   => Instead, we're going to parse the transform and call \"torchvision.transforms.functional\" (`tvf`)\n            resize_t, crop_t, norm_t = transform.transforms[0], transform.transforms[1], transform.transforms[3]\n            self.tvf_resize_params.append(\n                {\n                    \"size\": resize_t.size,\n                    \"interpolation\": TVF.pil_modes_mapping[resize_t.interpolation],\n                    \"max_size\": None,\n                    \"antialias\": True,\n                }\n            )\n            self.tvf_crop_params.append({\"output_size\": crop_t.size})\n            self.tvf_normalize_params.append(\n                {\n                    \"mean\": norm_t.mean.float().numpy().tolist(),\n                    \"std\": norm_t.std.float().numpy().tolist(),\n                    \"inplace\": False,\n                }\n            )\n            self.tvf_do_letterbox, self.tvf_letterbox_fill = False, None\n\n            # Handle Prismatic `image_resize_strategy`\n            if self.image_resize_strategy == \"resize-naive\":\n                self.tvf_resize_params[idx][\"size\"] = (resize_t.size, resize_t.size)\n            elif self.image_resize_strategy == \"letterbox\":\n                self.tvf_do_letterbox, self.tvf_letterbox_fill = True, tuple([int(x * 255) for x in self.means[idx]])\n            elif self.image_resize_strategy == \"resize-crop\":\n                pass\n            else:\n                raise ValueError(f\"Image resize strategy `{self.image_resize_strategy}` is not supported!\")\n\n        # Dispatch **kwargs to super()\n        super().__init__(**kwargs)\n\n    def apply_transform(self, img: Image.Image) -> torch.Tensor:\n        \"\"\"Apply `functional` variant of TIMM's Transform = Compose([Resize -> CenterCrop -> ToTensor -> Normalize])\"\"\"\n        if self.tvf_do_letterbox:\n            img = letterbox_pad_transform(img, self.tvf_letterbox_fill)\n\n        # [Contract] Fused Backbones expect \"channel-stacked\" inputs; we'll unpack on the 
model side!\n        imgs_t = []\n        for idx in range(len(self.input_sizes)):\n            img_idx = TVF.resize(img, **self.tvf_resize_params[idx])\n            img_idx = TVF.center_crop(img_idx, **self.tvf_crop_params[idx])\n            img_idx_t = TVF.to_tensor(img_idx)\n            img_idx_t = TVF.normalize(img_idx_t, **self.tvf_normalize_params[idx])\n            imgs_t.append(img_idx_t)\n\n        # [Contract] `imgs_t` is a list of Tensors of shape [3, input_size, input_size]; stack along dim = 0\n        img_t = torch.vstack(imgs_t)\n\n        return img_t\n\n    def preprocess(\n        self,\n        images: Union[Image.Image, List[Image.Image]],\n        return_tensors: Optional[Union[str, TensorType]] = None,\n        **_: str,\n    ) -> BatchFeature:\n        \"\"\"\n        Preprocess an image (or batch of images); note that unlike the `transformers :: BaseImageProcessor` we\n        explicitly only handle PIL.Image.Image instances for simplicity.\n        @param images: A (batch of) PIL.Image.Image instance(s) to preprocess.\n        @param return_tensors: BatchFeature default Tensor format (e.g., \"pt\" for torch); if None, returns np.ndarray\n        @return: Instance of `transformers :: BatchFeature` with a single key \"pixel_values\"\n        \"\"\"\n        if not isinstance(images, list):\n            images = [images]\n\n        # Apply `self.img_transform` to each image (will return list of torch.Tensors); stack into \"batched\" Tensor\n        pixel_values = torch.stack([self.apply_transform(img.convert(\"RGB\")) for img in images])\n\n        # Return BatchFeature =>> note that for compatibility, constructor expects Dict[str, np.ndarray], so we convert\n        return BatchFeature(data={\"pixel_values\": pixel_values.float().numpy()}, tensor_type=return_tensors)\n\n    def __call__(self, images: Union[Image.Image, List[Image.Image]], **kwargs) -> BatchFeature:\n        return self.preprocess(images, **kwargs)\n\n\n# === PrismaticProcessor =>> Wraps both ImageProcessor and Tokenizer ===\n#   =>> https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/processing_llava.py\nclass PrismaticProcessor(ProcessorMixin):\n    attributes: ClassVar[List[str]] = [\"image_processor\", \"tokenizer\"]\n    image_processor_class: str = \"AutoImageProcessor\"\n    tokenizer_class: str = \"AutoTokenizer\"\n\n    def __init__(\n        self,\n        image_processor: Optional[ImageProcessingMixin] = None,\n        tokenizer: Optional[PreTrainedTokenizerBase] = None,\n    ) -> None:\n        super().__init__(image_processor, tokenizer)\n\n    def __call__(\n        self,\n        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],\n        images: Union[Image.Image, List[Image.Image]],\n        padding: Union[bool, str, PaddingStrategy] = False,\n        truncation: Optional[Union[bool, str, TruncationStrategy]] = None,\n        max_length: Optional[int] = None,\n        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,\n    ) -> BatchFeature:\n        \"\"\"\n        Preprocess a given (batch) of text/images for a Prismatic VLM; forwards text to the underlying LLM's tokenizer,\n        forwards images to PrismaticImageProcessor.\n        @param text: The (batch) of text to encode; must be a string or list of strings.\n        @param images: A (batch of) PIL.Image.Image instance(s) to preprocess.\n        @param padding: Sequence padding strategy (if multiple specified) in < True = \"longest\" | 
\"max_length\" | False >\n        @param truncation: Truncation strategy for the output sequences; requires `max_length` to be specified\n        @param max_length: Maximum length (in tokens) to truncate\n        @param return_tensors: Type of return tensors (usually \"pt\" or TensorType.PYTORCH)\n        @return: BatchFeature with keys for `input_ids`, `attention_mask` and `pixel_values`.\n        \"\"\"\n        pixel_values = self.image_processor(images, return_tensors=return_tensors)[\"pixel_values\"]\n        text_inputs = self.tokenizer(\n            text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length\n        )\n\n        # [Validate] Need same number of images and text inputs!\n        if pixel_values.shape[0] != text_inputs.input_ids.shape[0]:\n            raise ValueError(\"Batch is malformed; expected same number of images and text inputs!\")\n\n        return BatchFeature(data={**text_inputs, \"pixel_values\": pixel_values})\n\n    # === Tokenizer Dispatch Utilities =>> check `PreTrainedTokenizerBase` for documentation ===\n    def batch_decode(\n        self,\n        sequences: Union[List[int], List[List[int]], torch.Tensor, Any],  # `Any` = np.ndarray | tf.Tensor\n        skip_special_tokens: bool = False,\n        clean_up_tokenization_spaces: Optional[bool] = None,\n        **kwargs: str,\n    ) -> List[str]:\n        return self.tokenizer.batch_decode(\n            sequences=sequences,\n            skip_special_tokens=skip_special_tokens,\n            clean_up_tokenization_spaces=clean_up_tokenization_spaces,\n            **kwargs,\n        )\n\n    def decode(\n        self,\n        token_ids: Union[int, List[int], torch.Tensor, Any],  # `Any` = np.ndarray | tf.Tensor\n        skip_special_tokens: bool = False,\n        clean_up_tokenization_spaces: Optional[bool] = None,\n        **kwargs: str,\n    ) -> str:\n        return self.tokenizer.decode(\n            token_ids=token_ids,\n            skip_special_tokens=skip_special_tokens,\n            clean_up_tokenization_spaces=clean_up_tokenization_spaces,\n            **kwargs,\n        )\n\n    @property\n    def model_input_names(self) -> List[str]:\n        tokenizer_input_names = self.tokenizer.model_input_names\n        image_processor_input_names = self.image_processor.model_input_names\n\n        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))\n"
  },
  {
    "path": "siirl/models/embodied/openvla_oft/train_utils.py",
    "content": "\"\"\"Utils for training/fine-tuning scripts.\"\"\"\n\nimport torch\n\nfrom .constants import ACTION_DIM, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX\n\n\ndef get_current_action_mask(token_ids):\n    # Create a tensor marking positions of IGNORE_INDEX\n    newline_positions = token_ids != IGNORE_INDEX\n\n    # Calculate cumulative sum to identify regions between newlines\n    cumsum = torch.cumsum(newline_positions, dim=1)\n\n    # Create the mask\n    mask = (1 <= cumsum) & (cumsum <= ACTION_DIM)\n\n    # Extract the action part only\n    action_tokens_only_mask = token_ids > ACTION_TOKEN_BEGIN_IDX\n    mask = action_tokens_only_mask * mask\n\n    return mask\n\n\ndef get_next_actions_mask(token_ids):\n    # Create a tensor marking positions of IGNORE_INDEX\n    newline_positions = token_ids != IGNORE_INDEX\n\n    # Calculate cumulative sum to identify regions between newlines\n    cumsum = torch.cumsum(newline_positions, dim=1)\n\n    # Create the mask\n    mask = cumsum > ACTION_DIM\n\n    # Extract the action part only\n    action_tokens_only_mask = token_ids > ACTION_TOKEN_BEGIN_IDX\n    mask = action_tokens_only_mask * mask\n\n    return mask\n\n\ndef compute_token_accuracy(predicted_token_ids, ground_truth_token_ids, mask):\n    correct_preds = (predicted_token_ids == ground_truth_token_ids) & mask\n    accuracy = correct_preds.sum().float() / mask.sum().float()\n    return accuracy\n\n\ndef compute_actions_l1_loss(action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask):\n    pred_continuous_actions = torch.tensor(\n        action_tokenizer.decode_token_ids_to_actions(predicted_token_ids[mask].cpu().numpy())\n    )\n    true_continuous_actions = torch.tensor(\n        action_tokenizer.decode_token_ids_to_actions(ground_truth_token_ids[mask].cpu().numpy())\n    )\n    l1_loss = torch.nn.functional.l1_loss(pred_continuous_actions, true_continuous_actions)\n    return l1_loss\n"
  },
  {
    "path": "siirl/models/llama/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "siirl/models/llama/megatron/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .modeling_llama_megatron import (\n    ParallelLlamaForCausalLM,\n    # rmpad with megatron\n    ParallelLlamaForCausalLMRmPad,\n    # rmpad with megatron and pipeline parallelism\n    ParallelLlamaForCausalLMRmPadPP,\n    ParallelLlamaForValueRmPad,\n    ParallelLlamaForValueRmPadPP,\n    # original model with megatron\n    ParallelLlamaModel,\n)\n\n__all__ = [\n    \"ParallelLlamaForCausalLM\",\n    \"ParallelLlamaForCausalLMRmPad\",\n    \"ParallelLlamaForCausalLMRmPadPP\",\n    \"ParallelLlamaForValueRmPad\",\n    \"ParallelLlamaForValueRmPadPP\",\n    \"ParallelLlamaModel\",\n]\n"
  },
  {
    "path": "siirl/models/llama/megatron/checkpoint_utils/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "siirl/models/llama/megatron/checkpoint_utils/llama_loader.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport time\n\nimport torch\nimport torch.distributed as dist\n\nfrom siirl.utils.extras.device import get_device_id, get_torch_device\n\n\ndef _megatron_calc_layer_map(config):\n    \"\"\"Calculate the mapping of global layer_idx to local layer_idx\n    Returns:\n        layer_map (Dict: int -> tuple(int, int, int)):\n            mapping from the global layer index to\n            a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)\n    \"\"\"\n    from megatron.core import mpu\n\n    print(f\"get megatron data parallel size: {mpu.get_data_parallel_world_size()}\")\n\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n\n    layer_map = dict()\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers\n\n    for pp_rank_idx in range(pp_size):\n        for virtual_pp_rank_idx in range(virtual_pp_size):\n            layer_offset = virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model\n            for layer_idx in range(num_layers_per_model):\n                layer_map[layer_offset + layer_idx] = (\n                    pp_rank_idx,\n                    virtual_pp_rank_idx,\n                    layer_idx,\n                )\n    return layer_map\n\n\ndef load_state_dict_to_megatron_llama(state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False):\n    \"\"\"Load merged state_dict to sharded Megatron module in training.\"\"\"\n    from megatron.core import DistributedDataParallel as LocalDDP\n    from megatron.core import mpu\n    from megatron.core.transformer.module import Float16Module\n    from torch.nn.parallel import DistributedDataParallel as torchDDP\n\n    from siirl.utils.megatron.megatron_utils import print_rank_0, unwrap_model\n\n    start_time = time.time()\n\n    def _get_gpt_model(model):\n        return model\n\n    def fetch_params(module):\n        for param in module.parameters():\n            torch.distributed.fetch(param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group())\n\n    dp_rank = mpu.get_data_parallel_rank()\n    pp_rank = mpu.get_pipeline_model_parallel_rank()\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n    mp_group = mpu.get_model_parallel_group()\n\n    if torch.distributed.get_rank() == 0:\n        assert mp_group.rank() == 0, f\"mp_rank:[{mp_group.rank}] != 0 on rank #0\"\n        assert pp_rank == 0, f\"pp_rank:[{pp_rank}] != 0 on rank #0\"\n        assert dp_rank == 0, f\"dp_rank:[{dp_rank}] != 0 on rank #0\"\n\n    if not isinstance(wrapped_models, (list, tuple)):\n        wrapped_models = 
list(wrapped_models)\n\n    assert len(wrapped_models) == virtual_pp_size\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, f\"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size {virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}\"\n\n    models = [None] * len(wrapped_models)\n\n    for i, wrapped_model in enumerate(wrapped_models):\n        models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))\n        gpt_model_module = _get_gpt_model(models[i])\n        assert len(gpt_model_module.model.layers) == num_layers_per_model\n\n    def _fetch_tensor(tensor, name) -> torch.Tensor:\n        \"\"\"fetch tensor\"\"\"\n        nonlocal state_dict\n        if tensor is not None:\n            tensor.data.copy_(state_dict[name])\n\n    def _fetch_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:\n        \"\"\"fetch tensor in tp shards\"\"\"\n        nonlocal state_dict\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        if name in state_dict:\n            full_weight = state_dict[name]\n\n            if mutate_func is not None:\n                full_weight = mutate_func(full_weight)\n            tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)\n            if tensor is not None:\n                tensor.data.copy_(tensor_chunk[tp_rank])\n        else:\n            print(f\"tp_shard tensor:[{name}] not in state_dict, skip loading\")\n\n    def _fetch_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:\n        \"\"\"fetch tensor in tp shards\"\"\"\n        nonlocal state_dict\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        if name in state_dict:\n            full_weight = state_dict[name]\n\n            if mutate_func is not None:\n                full_weight = mutate_func(full_weight)\n            tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)\n            if tensor is not None:\n                tensor.data.copy_(tensor_chunk[tp_rank])\n        else:\n            print(f\"tp_shard tensor:[{name}] not in state_dict, skip loading\")\n\n    def _fetch_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor:\n        \"\"\"fetch gate_up tensor in tp shards\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        if gate_name in state_dict and up_name in state_dict:\n            gate_weight = state_dict[gate_name]\n            up_weight = state_dict[up_name]\n            new_gate_up_weight = torch.empty(config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id())\n            for i in range(tp_size):\n                intermediate_size_tp = config.intermediate_size // tp_size\n                gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]\n                up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]\n                new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_(torch.cat([gate_weight_tp, up_weight_tp], dim=0))\n\n            tensor_chunk = torch.chunk(new_gate_up_weight, 
tp_size, dim=0)\n            if tensor is not None:\n                tensor.data.copy_(tensor_chunk[tp_rank])\n        else:\n            print(f\"tp_shard tensor:[{gate_name}, {up_name}] not in state_dict, skip loading\")\n\n    def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tensor:\n        \"\"\"fetch tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        assert q_name in state_dict and k_name in state_dict and v_name in state_dict\n        full_weight_q = state_dict[q_name]\n        full_weight_k = state_dict[k_name]\n        full_weight_v = state_dict[v_name]\n\n        hidden_size_per_head = config.hidden_size // config.num_attention_heads\n\n        if config.num_key_value_heads >= tp_size:\n            q_size_tp = config.hidden_size // tp_size\n            kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size\n            total_size = q_size_tp + 2 * kv_size_tp\n            new_weight_qkv = torch.empty(total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id())\n            for i in range(tp_size):\n                q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]\n                k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp]\n                v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp]\n                new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0))\n\n        else:\n            q_size_tp = config.hidden_size // tp_size\n            kv_size_tp = hidden_size_per_head\n            total_size = q_size_tp + 2 * kv_size_tp\n            new_weight_qkv = torch.empty(total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id())\n            for i in range(tp_size):\n                q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]\n                start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head\n                end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head\n                k_part = full_weight_k[start_idx:end_idx]\n                v_part = full_weight_v[start_idx:end_idx]\n                new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0))\n\n        tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0)\n        if tensor is not None:\n            tensor.data.copy_(tensor_chunk[tp_rank])\n\n    # Embeddings\n    # -------------------\n    print_rank_0(\"loading embeddings...\")\n    gpt_model_module = _get_gpt_model(models[0])\n    embed_tokens_weight = None\n    if pp_rank == 0:\n        embed_tokens_weight = gpt_model_module.model.embed_tokens.weight\n    _fetch_tp_shard_tensor_vocab(embed_tokens_weight, \"model.embed_tokens.weight\")\n\n    # Transformer layers\n    # -------------------\n    layer_map = _megatron_calc_layer_map(config)\n\n    pp_rank = mpu.get_pipeline_model_parallel_rank()\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    num_layer_per_pp = config.num_hidden_layers // pp_size\n    vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()\n\n    layer_list = []\n    if vpp_size is not None:\n        for vpp_rank in range(vpp_size):\n            num_layer_vpp_chunk = num_layer_per_pp // vpp_size\n            num_layer_this_model = num_layer_vpp_chunk\n            
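# Global index of the first layer handled by this (pp_rank, vpp_rank) chunk: each virtual\n            # pipeline stage spans num_hidden_layers // vpp_size layers, and within that span this\n            # pipeline rank owns a contiguous block of num_layer_vpp_chunk layers (mirrors _megatron_calc_layer_map).\n            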
offset = vpp_rank * (config.num_hidden_layers // mpu.get_virtual_pipeline_model_parallel_world_size()) + (mpu.get_pipeline_model_parallel_rank() * num_layer_vpp_chunk)\n            layer_list.extend(list(range(offset, offset + num_layer_this_model)))\n    else:\n        num_layer_this_model = num_layer_per_pp\n        offset = pp_rank * num_layer_per_pp\n        layer_list.extend(list(range(offset, offset + num_layer_this_model)))\n\n    for layer in layer_list:\n        print_rank_0(f\"loading layer #{layer}...\")\n        layer_name = f\"model.layers.{layer}\"\n        dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer]\n\n        gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank])\n        sync_layer = gpt_model_module.model.layers[dst_layer_idx]\n\n        _fetch_tensor(\n            sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None,\n            f\"{layer_name}.input_layernorm.weight\",\n        )\n\n        _fetch_tp_shard_tensor_qkv(\n            sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None,\n            f\"{layer_name}.self_attn.q_proj.weight\",\n            f\"{layer_name}.self_attn.k_proj.weight\",\n            f\"{layer_name}.self_attn.v_proj.weight\",\n        )\n\n        _fetch_tp_shard_tensor(\n            sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None,\n            f\"{layer_name}.self_attn.o_proj.weight\",\n            chunk_dim=1,\n        )\n\n        _fetch_tensor(\n            sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None,\n            f\"{layer_name}.post_attention_layernorm.weight\",\n        )\n\n        _fetch_tp_shard_tensor_gate_up(\n            sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None,\n            f\"{layer_name}.mlp.gate_proj.weight\",\n            f\"{layer_name}.mlp.up_proj.weight\",\n        )\n\n        _fetch_tp_shard_tensor(\n            sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None,\n            f\"{layer_name}.mlp.down_proj.weight\",\n            chunk_dim=1,\n        )\n    # Final Layernorm\n    # -------------------\n    print_rank_0(\"loading final layernorm...\")\n    gpt_model_module = _get_gpt_model(models[-1])\n    _fetch_tensor(\n        getattr(gpt_model_module.model.norm, \"weight\", None),\n        \"model.norm.weight\",\n    )\n\n    print_rank_0(\"loading lm_head...\")\n    if pp_rank + 1 == pp_size:\n        lm_head_weight = gpt_model_module.lm_head.weight\n\n        if is_value_model:\n            if \"lm_head.weight\" in state_dict and state_dict[\"lm_head.weight\"].shape[0] == 1:\n                _fetch_tensor(lm_head_weight, \"lm_head.weight\")\n                print_rank_0(\"load lm_head weight\")\n            elif \"reward_head.weight\" in state_dict and state_dict[\"reward_head.weight\"].shape[0] == 1:\n                _fetch_tensor(lm_head_weight, \"reward_head.weight\")\n                print_rank_0(\"load lm_head from value_head weight\")\n            else:\n                _fetch_tensor(None, \"lm_head.weight\")\n                print_rank_0(\"fail to match lm_head in value_model\")\n        else:\n            _fetch_tp_shard_tensor(lm_head_weight, \"lm_head.weight\")\n\n    dist.barrier()\n    get_torch_device().empty_cache()\n    print_rank_0(f\"loading megatron ckpt done, time elapsed {time.time() - start_time}s\")\n"
  },
  {
    "path": "siirl/models/llama/megatron/checkpoint_utils/llama_loader_depracated.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport time\n\nimport torch\nimport torch.distributed as dist\n\nfrom siirl.utils.extras.device import get_device_id, get_torch_device\n\n\ndef _megatron_calc_layer_map(config):\n    \"\"\"Calculate the mapping of global layer_idx to local layer_idx\n    Returns:\n        layer_map (Dict: int -> tuple(int, int, int)):\n            mapping from the global layer index to\n            a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)\n    \"\"\"\n    from megatron.core import mpu\n\n    print(f\"get megatron data parallel size: {mpu.get_data_parallel_world_size()}\")\n\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n\n    layer_map = dict()\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers\n\n    for pp_rank_idx in range(pp_size):\n        for virtual_pp_rank_idx in range(virtual_pp_size):\n            layer_offset = virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model\n            for layer_idx in range(num_layers_per_model):\n                layer_map[layer_offset + layer_idx] = (\n                    pp_rank_idx,\n                    virtual_pp_rank_idx,\n                    layer_idx,\n                )\n    return layer_map\n\n\ndef load_state_dict_to_megatron_llama(state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False):\n    \"\"\"Load merged state_dict to sharded Megatron module in training.\"\"\"\n    from megatron.core import DistributedDataParallel as LocalDDP\n    from megatron.core import mpu\n    from megatron.core.transformer.module import Float16Module\n    from torch.nn.parallel import DistributedDataParallel as torchDDP\n\n    from siirl.utils.megatron.megatron_utils import print_rank_0, unwrap_model\n\n    start_time = time.time()\n\n    def _get_gpt_model(model):\n        return model\n\n    def broadcast_params(module):\n        for param in module.parameters():\n            torch.distributed.broadcast(param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group())\n\n    dp_rank = mpu.get_data_parallel_rank()\n    pp_rank = mpu.get_pipeline_model_parallel_rank()\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n    mp_group = mpu.get_model_parallel_group()\n\n    if torch.distributed.get_rank() == 0:\n        assert mp_group.rank() == 0, f\"mp_rank:[{mp_group.rank}] != 0 on rank #0\"\n        assert pp_rank == 0, f\"pp_rank:[{pp_rank}] != 0 on rank #0\"\n        assert dp_rank == 0, f\"dp_rank:[{dp_rank}] != 0 on rank #0\"\n\n    if not isinstance(wrapped_models, (list, tuple)):\n        wrapped_models = 
list(wrapped_models)\n\n    assert len(wrapped_models) == virtual_pp_size\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, f\"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size {virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}\"\n\n    models = [None] * len(wrapped_models)\n\n    for i, wrapped_model in enumerate(wrapped_models):\n        models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))\n        gpt_model_module = _get_gpt_model(models[i])\n        assert len(gpt_model_module.model.layers) == num_layers_per_model\n\n    def _broadcast_tensor(tensor, name) -> torch.Tensor:\n        \"\"\"broadcast tensor from rank0 across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        if torch.distributed.get_rank() == 0:\n            if name in state_dict:\n                weight = state_dict[name]\n                tensor_shape = weight.shape\n            else:\n                tensor_shape = None\n        else:\n            weight = None\n            tensor_shape = None\n\n        obj_list = [tensor_shape]\n        dist.broadcast_object_list(obj_list, src=0, group=mp_group)\n        tensor_shape = obj_list[0]\n\n        if tensor_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tensor:[{name}] not in state_dict, skip load\")\n            return\n\n        if tensor is None:\n            tensor = torch.empty(\n                tensor_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        if torch.distributed.get_rank() == 0:\n            tensor.data.copy_(weight)\n        dist.broadcast(tensor, src=0, group=mp_group)\n\n    def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        if torch.distributed.get_rank() == 0:\n            if name in state_dict:\n                full_weight = state_dict[name]\n\n                if mutate_func is not None:\n                    full_weight = mutate_func(full_weight)\n                tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)\n                chunk_shape = tensor_chunk[0].shape\n            else:\n                chunk_shape = None\n        else:\n            chunk_shape = None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=0, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{name}] not in state_dict, skip loading\")\n            return\n\n        if tensor is None:\n            sync_tensor = torch.empty(\n                chunk_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        else:\n            assert tensor.shape == chunk_shape, f\"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}\"\n            sync_tensor = torch.empty_like(tensor, 
device=get_device_id(), requires_grad=False)\n\n        for i in range(tp_size):\n            if torch.distributed.get_rank() == 0:\n                sync_tensor.data.copy_(tensor_chunk[i])\n            dist.broadcast(sync_tensor, src=0, group=mp_group)\n            if (i == tp_rank) and (tensor is not None):\n                tensor.data.copy_(sync_tensor)\n\n    def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        if torch.distributed.get_rank() == 0:\n            if name in state_dict:\n                full_weight = state_dict[name]\n                if mutate_func is not None:\n                    full_weight = mutate_func(full_weight)\n                tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)\n                chunk_shape = tensor_chunk[0].shape\n            else:\n                chunk_shape = None\n        else:\n            chunk_shape = None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=0, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{name}] not in state_dict, skip loading\")\n            return\n\n        if tensor is None:\n            sync_tensor = torch.empty(\n                chunk_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        else:\n            assert tensor.shape == chunk_shape, f\"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}\"\n            sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False)\n\n        for i in range(tp_size):\n            if torch.distributed.get_rank() == 0:\n                sync_tensor.data.copy_(tensor_chunk[i])\n            dist.broadcast(sync_tensor, src=0, group=mp_group)\n            if (i == tp_rank) and (tensor is not None):\n                tensor.data.copy_(sync_tensor)\n\n    def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        if torch.distributed.get_rank() == 0:\n            gate_weight = state_dict[gate_name]\n            up_weight = state_dict[up_name]\n            new_gate_up_weight = torch.empty(config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id())\n            for i in range(tp_size):\n                intermediate_size_tp = config.intermediate_size // tp_size\n                gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]\n                up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]\n                new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_(torch.cat([gate_weight_tp, up_weight_tp], dim=0))\n\n            tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0)\n            chunk_shape = tensor_chunk[0].shape\n        
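# Non-zero ranks do not hold the full state_dict; they learn the shard shape via the\n        # broadcast_object_list call below and receive the actual shards over mp_group.\n        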
else:\n            chunk_shape = None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=0, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading\")\n            return\n\n        if tensor is None:\n            sync_tensor = torch.empty(\n                chunk_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        else:\n            assert tensor.shape == chunk_shape, f\"rank #{torch.distributed.get_rank() == 0:} tensor {gate_name, up_name} shape {tensor.shape} != {chunk_shape}\"\n            sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False)\n\n        for i in range(tp_size):\n            if torch.distributed.get_rank() == 0:\n                sync_tensor.data.copy_(tensor_chunk[i])\n            dist.broadcast(sync_tensor, src=0, group=mp_group)\n            if (i == tp_rank) and (tensor is not None):\n                tensor.data.copy_(sync_tensor)\n\n    def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        if torch.distributed.get_rank() == 0:\n            assert q_name in state_dict and k_name in state_dict and v_name in state_dict\n            full_weight_q = state_dict[q_name]\n            full_weight_k = state_dict[k_name]\n            full_weight_v = state_dict[v_name]\n\n            hidden_size_per_head = config.hidden_size // config.num_attention_heads\n\n            if config.num_key_value_heads >= tp_size:\n                q_size_tp = config.hidden_size // tp_size\n                kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size\n                total_size = q_size_tp + 2 * kv_size_tp\n                new_weight_qkv = torch.empty(total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id())\n                for i in range(tp_size):\n                    q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]\n                    k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp]\n                    v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp]\n                    new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0))\n\n            else:\n                q_size_tp = config.hidden_size // tp_size\n                kv_size_tp = hidden_size_per_head\n                total_size = q_size_tp + 2 * kv_size_tp\n                new_weight_qkv = torch.empty(total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id())\n                for i in range(tp_size):\n                    q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]\n                    start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head\n                    end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head\n                    k_part = full_weight_k[start_idx:end_idx]\n                    v_part = full_weight_v[start_idx:end_idx]\n                    new_weight_qkv[i * total_size : 
(i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0))\n\n            tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0)\n            chunk_shape = tensor_chunk[0].shape\n        else:\n            chunk_shape = None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=0, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{q_name, k_name, v_name}] not in state_dict, skip loading\")\n            return\n\n        if tensor is None:\n            sync_tensor = torch.empty(\n                chunk_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        else:\n            assert tensor.shape == chunk_shape, f\"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}\"\n            sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False)\n\n        for i in range(tp_size):\n            if torch.distributed.get_rank() == 0:\n                sync_tensor.data.copy_(tensor_chunk[i])\n            dist.broadcast(sync_tensor, src=0, group=mp_group)\n            if (i == tp_rank) and (tensor is not None):\n                tensor.data.copy_(sync_tensor)\n\n    if dp_rank == 0:\n        # Embeddings\n        # -------------------\n        print_rank_0(\"loading embeddings...\")\n        gpt_model_module = _get_gpt_model(models[0])\n        embed_tokens_weight = None\n        if pp_rank == 0:\n            embed_tokens_weight = gpt_model_module.model.embed_tokens.weight\n        _broadcast_tp_shard_tensor_vocab(embed_tokens_weight, \"model.embed_tokens.weight\")\n\n        # Transformer layers\n        # -------------------\n        layer_map = _megatron_calc_layer_map(config)\n\n        for layer in range(config.num_hidden_layers):\n            print_rank_0(f\"loading layer #{layer}...\")\n            layer_name = f\"model.layers.{layer}\"\n            dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer]\n\n            gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank])\n            sync_layer = gpt_model_module.model.layers[dst_layer_idx]\n\n            _broadcast_tensor(\n                sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.input_layernorm.weight\",\n            )\n\n            _broadcast_tp_shard_tensor_qkv(\n                sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.self_attn.q_proj.weight\",\n                f\"{layer_name}.self_attn.k_proj.weight\",\n                f\"{layer_name}.self_attn.v_proj.weight\",\n            )\n\n            _broadcast_tp_shard_tensor(\n                sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.self_attn.o_proj.weight\",\n                chunk_dim=1,\n            )\n\n            _broadcast_tensor(\n                sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.post_attention_layernorm.weight\",\n            )\n\n            _broadcast_tp_shard_tensor_gate_up(\n                sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.mlp.gate_proj.weight\",\n                
f\"{layer_name}.mlp.up_proj.weight\",\n            )\n\n            _broadcast_tp_shard_tensor(\n                sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.mlp.down_proj.weight\",\n                chunk_dim=1,\n            )\n        # Final Layernorm\n        # -------------------\n        print_rank_0(\"loading final layernorm...\")\n        gpt_model_module = _get_gpt_model(models[-1])\n        _broadcast_tensor(\n            getattr(gpt_model_module.model.norm, \"weight\", None),\n            \"model.norm.weight\",\n        )\n\n        print_rank_0(\"loading lm_head...\")\n        lm_head_weight = None\n        if pp_rank + 1 == pp_size:\n            lm_head_weight = gpt_model_module.lm_head.weight\n\n        if is_value_model:\n            if \"lm_head.weight\" in state_dict and state_dict[\"lm_head.weight\"].shape[0] == 1:\n                _broadcast_tensor(lm_head_weight, \"lm_head.weight\")\n                print_rank_0(\"load lm_head weight\")\n            elif \"reward_head.weight\" in state_dict and state_dict[\"reward_head.weight\"].shape[0] == 1:\n                _broadcast_tensor(lm_head_weight, \"reward_head.weight\")\n                print_rank_0(\"load lm_head from value_head weight\")\n            else:\n                _broadcast_tensor(None, \"lm_head.weight\")\n                print_rank_0(\"fail to match lm_head in value_model\")\n        else:\n            _broadcast_tp_shard_tensor(lm_head_weight, \"lm_head.weight\")\n    dist.barrier()\n    # Broadcast weights inside data parallel groups\n    for wrapped_model in wrapped_models:\n        broadcast_params(wrapped_model)\n\n    get_torch_device().empty_cache()\n    print_rank_0(f\"loading megatron ckpt done, time elapsed {time.time() - start_time}s\")\n"
  },
  {
    "path": "siirl/models/llama/megatron/checkpoint_utils/llama_saver.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport time\n\nimport torch\nimport torch.distributed as dist\nfrom megatron.core import mpu\nfrom megatron.core.distributed import DistributedDataParallel as LocalDDP\nfrom megatron.core.transformer.module import Float16Module\nfrom torch.nn.parallel import DistributedDataParallel as torchDDP\n\nfrom siirl.utils.extras.device import get_device_id, get_torch_device\nfrom siirl.utils.megatron.megatron_utils import print_rank_0, unwrap_model\n\n\ndef _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0):\n    \"\"\"given TP,DP,PP rank to get the global rank.\"\"\"\n\n    tp_size = mpu.get_tensor_model_parallel_world_size()\n    dp_size = mpu.get_data_parallel_world_size()\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    assert tp_size * dp_size * pp_size == torch.distributed.get_world_size(), f\"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}\"\n    # We only support TP-DP-PP grouping, for correctness when resharding\n    return (pp_rank * dp_size + dp_rank) * tp_size + tp_rank\n\n\ndef _megatron_calc_layer_map(config):\n    \"\"\"Calculate the mapping of global layer_idx to local layer_idx\n    Returns:\n        layer_map (Dict: int -> tuple(int, int, int)):\n            mapping from the global layer index to\n            a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)\n    \"\"\"\n    from megatron.core import mpu\n\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n\n    layer_map = dict()\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers\n\n    for pp_rank_idx in range(pp_size):\n        for virtual_pp_rank_idx in range(virtual_pp_size):\n            layer_offset = virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model\n            for layer_idx in range(num_layers_per_model):\n                layer_map[layer_offset + layer_idx] = (\n                    pp_rank_idx,\n                    virtual_pp_rank_idx,\n                    layer_idx,\n                )\n    return layer_map\n\n\ndef merge_megatron_ckpt_llama(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False):\n    \"\"\"Merge sharded parameters of a Megatron module into a merged checkpoint.\n\n    Args:\n        wrapped_models (list of megatron.core.distributed.DistributedDataParallel):\n            The local DDP wrapped megatron modules.\n        config (str or None):\n            HF config for model\n        dtype: model params type\n        is_value_model: if model is value model\n        tie_word_embeddings: tie_word_embeddings, not used in llama, only to keep same interface with qwen2\n    Returns:\n        state_dict (dict):\n           
 The merged state_dict in rank 0, and an empty dictionary in other ranks.\n    \"\"\"\n    start_time = time.time()\n\n    def _get_gpt_model(model):\n        return model\n\n    dp_rank = mpu.get_data_parallel_rank()\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    pp_rank = mpu.get_pipeline_model_parallel_rank()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n    mp_group = mpu.get_model_parallel_group()\n\n    if dist.get_rank() == 0:\n        assert mp_group.rank() == 0, f\"mp_rank:[{mp_group.rank}] != 0 on rank #0\"\n        assert pp_rank == 0, f\"pp_rank:[{pp_rank}] != 0 on rank #0\"\n        assert dp_rank == 0, f\"dp_rank:[{dp_rank}] != 0 on rank #0\"\n\n    if not isinstance(wrapped_models, (list, tuple)):\n        wrapped_models = list(wrapped_models)\n\n    assert len(wrapped_models) == virtual_pp_size\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers\n\n    models = [None] * len(wrapped_models)\n\n    for i, wrapped_model in enumerate(wrapped_models):\n        models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))\n        assert len(models[i].model.layers) == num_layers_per_model, \"len model layers {} not equal to num_layers_per_model {}\".format(len(models[i].model.layers), num_layers_per_model)\n\n    state_dict = dict()\n\n    def _get_cpu_tensor(tensor: torch.Tensor):\n        if tensor is None:\n            return None\n        if tensor.device == torch.device(\"cpu\"):\n            return tensor.detach().clone()\n        return tensor.detach().cpu()\n\n    def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor:\n        \"\"\"broadcast tensor across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)\n\n        if torch.distributed.get_rank() == src_rank:\n            if tensor is None:\n                weight = None\n                tensor_shape = None\n            else:\n                weight = tensor\n                tensor_shape = weight.shape\n        else:\n            weight = None\n            tensor_shape = None\n\n        obj_list = [tensor_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        tensor_shape = obj_list[0]\n\n        if tensor_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tensor:[{name}] not exist, skip collect\")\n            return\n\n        if weight is None:\n            weight = torch.empty(\n                tensor_shape,\n                dtype=dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n\n        dist.broadcast(weight, src=src_rank, group=mp_group)\n\n        if torch.distributed.get_rank() == 0:\n            state_dict[name] = _get_cpu_tensor(weight)\n\n    def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)\n\n        chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None\n\n        obj_list = [chunk_shape]\n        
dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{name}] not exist, skip collecting\")\n            return\n\n        buffer_tensor = torch.empty(\n            chunk_shape,\n            dtype=dtype,\n            device=get_device_id(),\n            requires_grad=False,\n        )\n\n        chunk_tensors = [None] * tp_size\n\n        for i in range(tp_size):\n            cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)\n            sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor\n            dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)\n\n            if torch.distributed.get_rank() == 0:\n                chunk_tensors[i] = _get_cpu_tensor(sync_tensor)\n\n        if torch.distributed.get_rank() == 0:\n            full_tensor = torch.concat(chunk_tensors, dim=concat_dim)\n            if mutate_func is not None:\n                full_tensor = mutate_func(full_tensor)\n            state_dict[name] = full_tensor\n\n    def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)\n\n        chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting\")\n            return\n\n        buffer_tensor = torch.empty(\n            chunk_shape,\n            dtype=dtype,\n            device=get_device_id(),\n            requires_grad=False,\n        )\n\n        chunk_tensors = [None] * tp_size\n\n        for i in range(tp_size):\n            cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)\n            sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor\n            dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)\n\n            if torch.distributed.get_rank() == 0:\n                chunk_tensors[i] = _get_cpu_tensor(sync_tensor)\n\n        if torch.distributed.get_rank() == 0:\n            full_tensor = torch.concat(chunk_tensors, dim=0)\n            intermediate_size_tp = config.intermediate_size // tp_size\n            gate_weight_list = []\n            up_weight_list = []\n            for i in range(tp_size):\n                gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)]\n                gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp]\n                up_weight_tp = gate_up_weight_tp[intermediate_size_tp:]\n                gate_weight_list.append(gate_weight_tp)\n                up_weight_list.append(up_weight_tp)\n\n            state_dict[gate_name] = torch.cat(gate_weight_list, dim=0)\n            state_dict[up_name] = torch.cat(up_weight_list, dim=0)\n\n    def _broadcast_tp_shard_tensor_qkv(tensor, 
q_name, k_name, v_name, src_pp_rank):\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)\n\n        chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{q_name}] not exist, skip collecting\")\n            return\n\n        buffer_tensor = torch.empty(\n            chunk_shape,\n            dtype=dtype,\n            device=get_device_id(),\n            requires_grad=False,\n        )\n\n        chunk_tensors = [None] * tp_size\n\n        for i in range(tp_size):\n            cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)\n            sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor\n            dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)\n\n            if torch.distributed.get_rank() == 0:\n                chunk_tensors[i] = _get_cpu_tensor(sync_tensor)\n\n        if torch.distributed.get_rank() == 0:\n            full_tensor = torch.concat(chunk_tensors, dim=0)\n            q_weight_list = []\n            k_weight_list = []\n            v_weight_list = []\n            hidden_size_per_head = config.hidden_size // config.num_attention_heads\n\n            if config.num_key_value_heads >= tp_size:\n                q_size_tp = config.hidden_size // tp_size\n                kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size\n                total_size = q_size_tp + 2 * kv_size_tp\n                for i in range(tp_size):\n                    qkv_part = full_tensor[i * total_size : (i + 1) * total_size]\n                    q_part = qkv_part[:q_size_tp]\n                    k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp]\n                    v_part = qkv_part[q_size_tp + kv_size_tp : total_size]\n                    q_weight_list.append(q_part)\n                    k_weight_list.append(k_part)\n                    v_weight_list.append(v_part)\n            else:\n                q_size_tp = config.hidden_size // tp_size\n                kv_size_tp = hidden_size_per_head\n                total_size = q_size_tp + 2 * kv_size_tp\n                for i in range(tp_size):\n                    qkv_part = full_tensor[i * total_size : (i + 1) * total_size]\n                    q_part = qkv_part[:q_size_tp]\n                    k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp]\n                    v_part = qkv_part[q_size_tp + kv_size_tp : total_size]\n                    q_weight_list.append(q_part)\n                    if i * config.num_key_value_heads % tp_size == 0:\n                        k_weight_list.append(k_part)\n                        v_weight_list.append(v_part)\n\n            state_dict[q_name] = torch.cat(q_weight_list, dim=0)\n            state_dict[k_name] = torch.cat(k_weight_list, dim=0)\n            state_dict[v_name] = torch.cat(v_weight_list, dim=0)\n\n    # empty cache before collecting weights\n    get_torch_device().empty_cache()\n    # Embeddings\n    # -------------------\n    if dp_rank == 0:\n        # Embeddings\n 
       # -------------------\n        print_rank_0(\"collecting embeddings...\")\n        gpt_model_module = _get_gpt_model(models[0])\n        _broadcast_tp_shard_tensor(\n            gpt_model_module.model.embed_tokens.weight if pp_rank == 0 else None,\n            \"model.embed_tokens.weight\",\n            src_pp_rank=0,\n        )\n\n        # Transformer layers\n        # -------------------\n        layer_map = _megatron_calc_layer_map(config)\n        for layer in range(config.num_hidden_layers):\n            print_rank_0(f\"collecting layer #{layer}...\")\n            layer_name = f\"model.layers.{layer}\"\n            src_pp_rank, src_virtual_pp_rank, src_layer_idx = layer_map[layer]\n\n            gpt_model_module = _get_gpt_model(models[src_virtual_pp_rank])\n            sync_layer = gpt_model_module.model.layers[src_layer_idx]\n\n            _broadcast_tensor(\n                sync_layer.input_layernorm.weight,\n                f\"{layer_name}.input_layernorm.weight\",\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tp_shard_tensor_qkv(\n                sync_layer.self_attn.qkv_proj.weight,\n                f\"{layer_name}.self_attn.q_proj.weight\",\n                f\"{layer_name}.self_attn.k_proj.weight\",\n                f\"{layer_name}.self_attn.v_proj.weight\",\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tp_shard_tensor(\n                sync_layer.self_attn.o_proj.weight,\n                f\"{layer_name}.self_attn.o_proj.weight\",\n                concat_dim=1,\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tensor(\n                sync_layer.post_attention_layernorm.weight,\n                f\"{layer_name}.post_attention_layernorm.weight\",\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tp_shard_tensor_gate_up(\n                sync_layer.mlp.gate_up_proj.weight,\n                f\"{layer_name}.mlp.gate_proj.weight\",\n                f\"{layer_name}.mlp.up_proj.weight\",\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tp_shard_tensor(\n                sync_layer.mlp.down_proj.weight,\n                f\"{layer_name}.mlp.down_proj.weight\",\n                concat_dim=1,\n                src_pp_rank=src_pp_rank,\n            )\n\n        # Final Layernorm\n        # -------------------\n        print_rank_0(\"collecting final layernorm...\")\n        gpt_model_module = _get_gpt_model(models[-1])\n        _broadcast_tensor(\n            getattr(gpt_model_module.model.norm, \"weight\", None),\n            \"model.norm.weight\",\n            src_pp_rank=pp_size - 1,\n        )\n\n        print_rank_0(\"collecting lm_head...\")\n\n        if is_value_model:\n            if pp_rank == pp_size - 1:\n                print(f\"gpt_model_module.lm_head.weight: {gpt_model_module.lm_head.weight.shape}\")\n            _broadcast_tensor(\n                gpt_model_module.lm_head.weight if pp_rank == pp_size - 1 else None,\n                \"lm_head.weight\",\n                src_pp_rank=pp_size - 1,\n            )\n            _broadcast_tensor(\n                gpt_model_module.reward_head.weight if pp_rank == pp_size - 1 and getattr(gpt_model_module, \"reward_weight\", None) is not None else None,\n                \"reward_head.weight\",\n                src_pp_rank=pp_size - 1,\n            )\n\n        else:\n            _broadcast_tp_shard_tensor(\n                getattr(gpt_model_module.lm_head, 
\"weight\", None) if pp_rank == pp_size - 1 else None,\n                \"lm_head.weight\",\n                src_pp_rank=pp_size - 1,\n            )\n\n    dist.barrier()\n\n    get_torch_device().empty_cache()\n    if torch.distributed.get_rank() == 0:\n        if dtype not in [torch.float16, torch.bfloat16, torch.float32]:\n            print(f'Unknown/unsupported dtype to save: {dtype}\"')\n            exit(1)\n        for k, v in state_dict.items():\n            if dtype != v.dtype:\n                state_dict[k] = v.to(dtype)\n\n    print_rank_0(f\"merge megatron ckpt done, time elapsed {time.time() - start_time}s\")\n    return state_dict\n"
  },
  {
    "path": "siirl/models/llama/megatron/layers/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .parallel_attention import ParallelLlamaAttention\nfrom .parallel_decoder import ParallelLlamaDecoderLayer, ParallelLlamaDecoderLayerRmPad\nfrom .parallel_linear import (\n    LinearForLastLayer,\n    MergedColumnParallelLinear,\n    QKVParallelLinear,\n)\nfrom .parallel_mlp import ParallelLlamaMLP\nfrom .parallel_rmsnorm import ParallelLlamaRMSNorm\n\n__all__ = [\"LinearForLastLayer\", \"MergedColumnParallelLinear\", \"QKVParallelLinear\", \"ParallelLlamaAttention\", \"ParallelLlamaDecoderLayer\", \"ParallelLlamaDecoderLayerRmPad\", \"ParallelLlamaMLP\", \"ParallelLlamaRMSNorm\"]\n"
  },
  {
    "path": "siirl/models/llama/megatron/layers/parallel_attention.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport math\nfrom typing import Optional, Tuple\n\nimport torch\nimport torch.nn.functional as F\nfrom einops import rearrange\nfrom flash_attn.layers.rotary import apply_rotary_emb\nfrom megatron.core import ModelParallelConfig, tensor_parallel\nfrom megatron.core import parallel_state as mpu\nfrom torch import nn\nfrom transformers import LlamaConfig\nfrom transformers.utils import is_flash_attn_2_available\n\nfrom siirl.models.llama.megatron.layers.parallel_linear import QKVParallelLinear\nfrom siirl.utils.megatron import tensor_parallel as tp_utils\n\n\nclass LlamaRotaryEmbedding(nn.Module):\n    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):\n        super().__init__()\n\n        self.dim = dim\n        self.max_position_embeddings = max_position_embeddings\n        self.base = base\n        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n\n        # Build here to make `torch.jit.trace` work.\n        self._set_cos_sin_cache(seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype())\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)\n\n        freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\"cos_cached\", emb.cos().to(dtype), persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin().to(dtype), persistent=False)\n\n    def forward(self, x, seq_len=None):\n        # x: [bs, num_attention_heads, seq_len, head_size]\n        if seq_len > self.max_seq_len_cached:\n            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)\n\n        return (\n            self.cos_cached[:seq_len].to(dtype=x.dtype),\n            self.sin_cached[:seq_len].to(dtype=x.dtype),\n        )\n\n\nclass LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):\n    \"\"\"LlamaRotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev\"\"\"\n\n    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):\n        self.scaling_factor = scaling_factor\n        super().__init__(dim, max_position_embeddings, base, device)\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)\n        t = t / self.scaling_factor\n\n        freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\"cos_cached\", emb.cos().to(dtype), persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin().to(dtype), persistent=False)\n\n\nclass LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):\n    \"\"\"LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla\"\"\"\n\n    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):\n        self.scaling_factor = scaling_factor\n        super().__init__(dim, max_position_embeddings, base, device)\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n\n        if seq_len > self.max_position_embeddings:\n            base = self.base * ((self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)) ** (self.dim / (self.dim - 2))\n            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))\n            self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n\n        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)\n\n        freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\"cos_cached\", emb.cos().to(dtype), persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin().to(dtype), persistent=False)\n\n\nclass LlamaLlama3ScalingRotaryEmbedding(LlamaRotaryEmbedding):\n    def __init__(self, dim, config, max_position_embeddings=2048, base=10000, device=None):\n        super().__init__(dim, max_position_embeddings, base, device)\n\n        self.factor = config.rope_scaling[\"factor\"]  # `8` in the original implementation\n        self.high_freq_factor = config.rope_scaling[\"high_freq_factor\"]  # `1` in the original implementation\n        self.low_freq_factor = config.rope_scaling[\"low_freq_factor\"]  # `4` in the original implementation\n        self.old_context_len = config.rope_scaling[\"original_max_position_embeddings\"]  # `8192` in the original implementation\n\n        low_freq_wavelen = self.old_context_len / self.low_freq_factor\n        high_freq_wavelen = self.old_context_len / self.high_freq_factor\n\n        wavelen = 2 * math.pi / self.inv_freq\n        # wavelen < high_freq_wavelen: do nothing; wavelen > low_freq_wavelen: divide by factor\n        inv_freq_llama = torch.where(wavelen > low_freq_wavelen, self.inv_freq / self.factor, self.inv_freq)\n        # otherwise: interpolate between the two, using a smooth factor\n        smooth_factor = (self.old_context_len / wavelen - self.low_freq_factor) / 
(self.high_freq_factor - self.low_freq_factor)\n        smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / self.factor + smooth_factor * inv_freq_llama\n        is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)\n        inv_freq = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)\n\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n\n        # Build here to make `torch.jit.trace` work.\n        self._set_cos_sin_cache(seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype())\n\n\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids):\n    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]\n    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed, k_embed\n\n\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\nclass ParallelLlamaAttention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):\n        super().__init__()\n        self.config = config\n        self.megatron_config = megatron_config\n        self.hidden_size = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_dim = self.hidden_size // self.num_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.num_key_value_groups = self.num_heads // self.num_key_value_heads\n        self.max_position_embeddings = config.max_position_embeddings\n        self.rope_theta = config.rope_theta\n\n        # assign values after tp\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        assert self.num_heads % tp_size == 0, f\"num_head must be divisible by tp_size. Got num_head={self.num_heads}, tp_size={tp_size}\"\n        assert self.num_key_value_heads % tp_size == 0, f\"num_key_value_heads must be divisible by tp_size. 
Got num_key_value_heads={self.num_key_value_heads}, tp_size={tp_size}\"\n\n        self.num_heads_per_tp = self.num_heads // tp_size\n        self.num_key_value_heads_per_tp = self.num_key_value_heads // tp_size\n        self.hidden_size_per_tp = self.hidden_size // tp_size\n\n        if (self.head_dim * self.num_heads) != self.hidden_size:\n            raise ValueError(f\"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`: {self.num_heads}).\")\n\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear()\n\n        if megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            assert row_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(column_kwargs, megatron_config)\n            tp_utils.update_kwargs_with_config(row_kwargs, megatron_config)\n\n        # [self.q_size, self.k_size, self.v_size]\n        self.qkv_proj = QKVParallelLinear(\n            input_size=self.hidden_size,\n            num_heads=self.num_heads,\n            num_key_value_heads=self.num_key_value_heads,\n            head_dim=self.head_dim,\n            bias=config.attention_bias,\n            gather_output=False,\n            skip_bias_add=False,\n            **column_kwargs,\n        )\n\n        self.q_size = self.num_heads_per_tp * self.head_dim\n        self.k_size = self.num_key_value_heads_per_tp * self.head_dim\n        self.v_size = self.num_key_value_heads_per_tp * self.head_dim\n\n        self.o_proj = tensor_parallel.RowParallelLinear(\n            input_size=self.num_heads * self.head_dim,\n            output_size=self.hidden_size,\n            bias=config.attention_bias,\n            input_is_parallel=True,\n            skip_bias_add=False,\n            **row_kwargs,\n        )\n\n        self._init_rope()\n\n    def _init_rope(self):\n        if self.config.rope_scaling is None:\n            self.rotary_emb = LlamaRotaryEmbedding(\n                self.head_dim,\n                max_position_embeddings=self.max_position_embeddings,\n                base=self.rope_theta,\n            )\n        else:\n            rope_type_key = \"type\" if \"type\" in self.config.rope_scaling else \"rope_type\"\n            scaling_type = self.config.rope_scaling[rope_type_key]\n            scaling_factor = self.config.rope_scaling[\"factor\"]\n            if scaling_type == \"linear\":\n                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(\n                    self.head_dim,\n                    max_position_embeddings=self.max_position_embeddings,\n                    scaling_factor=scaling_factor,\n                    base=self.rope_theta,\n                )\n            elif scaling_type == \"dynamic\":\n                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(\n                    self.head_dim,\n                    max_position_embeddings=self.max_position_embeddings,\n                    scaling_factor=scaling_factor,\n                    base=self.rope_theta,\n                )\n            elif scaling_type == \"llama3\":\n                self.rotary_emb = LlamaLlama3ScalingRotaryEmbedding(\n                    self.head_dim,\n                    self.config,\n                    max_position_embeddings=self.max_position_embeddings,\n                    base=self.rope_theta,\n                )\n         
   else:\n                raise ValueError(f\"Unknown RoPE scaling type {scaling_type}\")\n\n    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):\n        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        bsz, q_len, _ = hidden_states.size()\n        qkv = self.qkv_proj(hidden_states)[0]\n        query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)\n\n        query_states = query_states.view(bsz, q_len, self.num_heads_per_tp, self.head_dim).transpose(1, 2)\n        key_states = key_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2)\n        value_states = value_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2)\n\n        kv_seq_len = key_states.shape[-2]\n        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)\n        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)\n\n        key_states = repeat_kv(key_states, self.num_key_value_groups)\n        value_states = repeat_kv(value_states, self.num_key_value_groups)\n\n        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)\n\n        if attn_weights.size() != (bsz, self.num_heads_per_tp, q_len, kv_seq_len):\n            raise ValueError(f\"Attention weights should be of size {(bsz, self.num_heads_per_tp, q_len, kv_seq_len)}, but is {attn_weights.size()}\")\n\n        if attention_mask is not None:\n            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):\n                raise ValueError(f\"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}\")\n            attn_weights = attn_weights + attention_mask\n\n        # upcast attention to fp32\n        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)\n        attn_output = torch.matmul(attn_weights, value_states)\n\n        if attn_output.size() != (bsz, self.num_heads_per_tp, q_len, self.head_dim):\n            raise ValueError(f\"`attn_output` should be of size {(bsz, self.num_heads_per_tp, q_len, self.head_dim)}, but is {attn_output.size()}\")\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size_per_tp)\n        attn_output = self.o_proj(attn_output)[0]\n        return attn_output\n\n\n\"\"\"\nRemove padding Attention\n- Using Flash-attn 2\n- Compatible with sequence parallel\n\"\"\"\n\n\nif is_flash_attn_2_available():\n    from flash_attn import flash_attn_varlen_func\n    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa\n\n\ndef apply_rotary_pos_emb_rmpad(q, k, cos, sin, position_ids, indices, sequence_length):\n    batch_size = position_ids.shape[0]\n\n    q = pad_input(q, indices, batch_size, sequence_length)  # (batch_size, seqlen, num_head, head_dim)\n    k = pad_input(k, indices, batch_size, sequence_length)\n    cos = cos[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]\n    sin = sin[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    
k_embed = (k * cos) + (rotate_half(k) * sin)\n\n    q_embed = index_first_axis(rearrange(q_embed, \"b s ... -> (b s) ...\"), indices)\n    k_embed = index_first_axis(rearrange(k_embed, \"b s ... -> (b s) ...\"), indices)\n\n    return q_embed, k_embed\n\n\n# use flash-attn rotary embeddings with rmpad\n# cos/sin shoudl be: (seq_length, rotary_dim / 2)\ndef apply_rotary_pos_emb_rmpad_flash(q, k, cos, sin, cu_seqlens, max_seqlen):\n    q_embed = apply_rotary_emb(q, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)\n    k_embed = apply_rotary_emb(k, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)\n    return q_embed, k_embed\n\n\nclass ParallelLlamaAttentionRmPad(ParallelLlamaAttention):\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        position_ids: Optional[torch.LongTensor] = None,\n        sequence_length: int = None,\n        indices: torch.Tensor = None,\n        cu_seqlens: torch.Tensor = None,\n        max_seqlen_in_batch: int = None,\n    ):\n        total_nnz, _, _ = hidden_states.size()  # This is the total_nnz padded after sequence parallel\n\n        if self.megatron_config.sequence_parallel:\n            total_nnz = total_nnz * mpu.get_tensor_model_parallel_world_size()\n\n        qkv = self.qkv_proj(hidden_states)[0]\n        query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)  # (total_nnz, 1, hidden_size)\n\n        if self.megatron_config.sequence_parallel:\n            sequence_parallel_pad = total_nnz - cu_seqlens[-1]\n            total_nnz = cu_seqlens[-1]  # total_nnz before sp padding\n            query_states = query_states[:total_nnz]\n            key_states = key_states[:total_nnz]\n            value_states = value_states[:total_nnz]\n\n        # Flash attention requires the input to have the shape\n        # batch_size x seq_length x head_dime x hidden_dim\n        # therefore we just need to keep the original shape\n        query_states = query_states.view(total_nnz, self.num_heads_per_tp, self.head_dim)\n        key_states = key_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim)\n        value_states = value_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim)\n\n        cos, sin = self.rotary_emb(value_states, seq_len=sequence_length)\n        cos, sin = cos[:, : cos.shape[1] // 2], sin[:, : sin.shape[1] // 2]  # flash attn only needs half\n        query_states, key_states = apply_rotary_pos_emb_rmpad_flash(query_states, key_states, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen_in_batch)\n        # query_states, key_states = apply_rotary_pos_emb_rmpad(query_states, key_states, cos, sin, position_ids, indices,\n\n        # TODO: llama does not have dropout in the config??\n        # It is recommended to use dropout with FA according to the docs\n        # when training.\n        dropout_rate = 0.0  # if not self.training else self.attn_dropout\n\n        # In PEFT, usually we cast the layer norms in float32 for training stability reasons\n        # therefore the input hidden states gets silently casted in float32. Hence, we need\n        # cast them back in float16 just to be sure everything works as expected.\n        # This might slowdown training & inference so it is recommended to not cast the LayerNorms\n        # in fp32. 
(LlamaRMSNorm handles it correctly)\n        input_dtype = query_states.dtype\n        if input_dtype == torch.float32:\n            query_states = query_states.to(torch.float16)\n            key_states = key_states.to(torch.float16)\n            value_states = value_states.to(torch.float16)\n\n        attn_output_unpad = flash_attn_varlen_func(\n            query_states,\n            key_states,\n            value_states,\n            cu_seqlens_q=cu_seqlens,\n            cu_seqlens_k=cu_seqlens,\n            max_seqlen_q=max_seqlen_in_batch,\n            max_seqlen_k=max_seqlen_in_batch,\n            dropout_p=dropout_rate,\n            softmax_scale=None,\n            causal=True,\n        )\n\n        attn_output_unpad = attn_output_unpad.to(input_dtype)\n        attn_output_unpad = attn_output_unpad.reshape(total_nnz, 1, self.hidden_size_per_tp).contiguous()\n\n        # sequence parallel reduce_scatter is performed inside RowColumnParallel if enabled\n        # Here we need to repad\n        if self.megatron_config.sequence_parallel:\n            attn_output_unpad = F.pad(attn_output_unpad, pad=(0, 0, 0, 0, 0, sequence_parallel_pad))\n\n        attn_output_unpad = self.o_proj(attn_output_unpad)[0]\n        return attn_output_unpad\n"
  },
  {
    "path": "siirl/models/llama/megatron/layers/parallel_decoder.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Optional, Tuple\n\nimport torch\nfrom megatron.core import ModelParallelConfig\nfrom torch import nn\nfrom transformers import LlamaConfig\n\nfrom siirl.utils.megatron.megatron_utils import TransformerConfig, convert_config\n\nfrom siirl.models.llama.megatron.layers.parallel_attention import ParallelLlamaAttention, ParallelLlamaAttentionRmPad\nfrom siirl.models.llama.megatron.layers.parallel_mlp import ParallelLlamaMLP\nfrom siirl.models.llama.megatron.layers.parallel_rmsnorm import ParallelLlamaRMSNorm\n\n\nclass ParallelLlamaDecoderLayer(nn.Module):\n    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, layer_idx: int):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.layer_idx = layer_idx\n        self.hidden_size = config.hidden_size\n        self.self_attn = ParallelLlamaAttention(config=config, megatron_config=megatron_config)\n\n        self.mlp = ParallelLlamaMLP(config, megatron_config=megatron_config)\n        self.input_layernorm = ParallelLlamaRMSNorm(config, megatron_config)\n        self.post_attention_layernorm = ParallelLlamaRMSNorm(config, megatron_config)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size\n                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under\n                returned tensors for more detail.\n            use_cache (`bool`, *optional*):\n                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n                (see `past_key_values`).\n            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n        \"\"\"\n\n        residual = hidden_states\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Note: sequence parallel is hidden inside ColumnParallelLinear\n        # reduce scatter is hidden inside RowParallelLinear\n\n        # Self Attention\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n        )\n\n        # TODO: add sequence parallel operator reduce_scatter here\n\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n\n        # TODO: add sequence parallel operator all_gather here\n\n        hidden_states = self.mlp(hidden_states)\n\n        # TODO: add sequence parallel operator reduce_scatter here\n\n        hidden_states = residual + hidden_states\n\n        outputs = hidden_states\n\n        return outputs\n\n\nclass ParallelLlamaDecoderLayerRmPad(nn.Module):\n    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, layer_idx: int):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.layer_idx = layer_idx\n        self.hidden_size = config.hidden_size\n        self.self_attn = ParallelLlamaAttentionRmPad(config=config, megatron_config=megatron_config)\n\n        self.mlp = ParallelLlamaMLP(config, megatron_config=megatron_config)\n        self.input_layernorm = ParallelLlamaRMSNorm(config, megatron_config)\n        self.post_attention_layernorm = ParallelLlamaRMSNorm(config, megatron_config)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        position_ids: Optional[torch.LongTensor] = None,\n        sequence_length: int = None,\n        indices: torch.Tensor = None,\n        cu_seqlens: int = None,\n        max_seqlen_in_batch: int = None,\n    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:\n        residual = hidden_states  # (total_nnz // sp, 1, hidden_size)\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        # (total_nnz // sp, 1, hidden_size) -> all-gather (total_nnz, 1, hidden_size)\n        # -> col + row -> reduce-scatter -> (total_nnz // sp, 1, hidden_size)\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            position_ids=position_ids,\n            sequence_length=sequence_length,\n            indices=indices,\n            cu_seqlens=cu_seqlens,\n            max_seqlen_in_batch=max_seqlen_in_batch,\n        )\n\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        # shape changes same as attn\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n\n        outputs = hidden_states\n\n        return outputs\n"
  },
  {
    "path": "siirl/models/llama/megatron/layers/parallel_linear.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/linear.py\n\nimport torch\nfrom megatron.core import tensor_parallel\n\n\nclass QKVParallelLinear(tensor_parallel.ColumnParallelLinear):\n    def __init__(\n        self,\n        input_size,\n        num_heads,\n        num_key_value_heads,\n        head_dim,\n        *,\n        bias=True,\n        gather_output=True,\n        skip_bias_add=False,\n        **kwargs,\n    ):\n        # Keep input parameters, and already restrict the head numbers\n        self.input_size = input_size\n        self.q_output_size = num_heads * head_dim\n        self.kv_output_size = num_key_value_heads * head_dim\n        self.head_dim = head_dim\n        self.gather_output = gather_output\n        self.skip_bias_add = skip_bias_add\n\n        input_size = self.input_size\n        output_size = (num_heads + 2 * num_key_value_heads) * self.head_dim\n\n        super().__init__(\n            input_size=input_size,\n            output_size=output_size,\n            bias=bias,\n            gather_output=gather_output,\n            skip_bias_add=skip_bias_add,\n            **kwargs,\n        )\n\n\nclass MergedColumnParallelLinear(tensor_parallel.ColumnParallelLinear):\n    def __init__(\n        self,\n        input_size,\n        gate_ouput_size,\n        up_output_size,\n        *,\n        bias=True,\n        gather_output=True,\n        skip_bias_add=False,\n        **kwargs,\n    ):\n        # Keep input parameters, and already restrict the head numbers\n        self.input_size = input_size\n        self.output_size = gate_ouput_size + up_output_size\n        self.gather_output = gather_output\n        self.skip_bias_add = skip_bias_add\n\n        super().__init__(\n            input_size=self.input_size,\n            output_size=self.output_size,\n            bias=bias,\n            gather_output=gather_output,\n            skip_bias_add=skip_bias_add,\n            **kwargs,\n        )\n\n\nclass LinearForLastLayer(torch.nn.Linear):\n    def __init__(\n        self,\n        input_size,\n        output_size,\n        *,\n        config,\n        bias=True,\n    ):\n        super().__init__(in_features=input_size, out_features=output_size, bias=bias)\n        self.sequence_parallel = config.sequence_parallel\n        if self.sequence_parallel:\n            self.weight.sequence_parallel = True\n\n    def forward(\n        self,\n        input_,\n        weight=None,\n        runtime_gather_output=None,\n    ):\n        logits = super().forward(input_)\n        logits = logits.float()\n        if self.sequence_parallel:\n            logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False)\n        return logits, None\n"
  },
  {
    "path": "siirl/models/llama/megatron/layers/parallel_mlp.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom megatron.core import ModelParallelConfig, tensor_parallel\nfrom megatron.core import parallel_state as mpu\nfrom torch import nn\nfrom transformers.activations import ACT2FN\n\nfrom siirl.models.llama.megatron.layers.parallel_linear import MergedColumnParallelLinear\nfrom siirl.utils.megatron import tensor_parallel as tp_utils\n\n\nclass ParallelLlamaMLP(nn.Module):\n    def __init__(self, config, megatron_config: ModelParallelConfig = None) -> None:\n        super().__init__()\n        self.config = config\n        self.hidden_size = config.hidden_size\n        self.intermediate_size = config.intermediate_size\n        # The weight is only [hidden_size, intermediate_size // model_parallel_world_size]\n\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear()\n\n        if megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            assert row_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(row_kwargs, megatron_config)\n            tp_utils.update_kwargs_with_config(column_kwargs, megatron_config)\n\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        self.gate_up_proj = MergedColumnParallelLinear(\n            input_size=self.hidden_size,\n            gate_ouput_size=self.intermediate_size,\n            up_output_size=self.intermediate_size,\n            bias=False,\n            gather_output=False,\n            skip_bias_add=False,\n            **column_kwargs,\n        )\n        self.gate_size = self.intermediate_size // tp_size\n\n        self.down_proj = tensor_parallel.RowParallelLinear(\n            input_size=self.intermediate_size,\n            output_size=self.hidden_size,\n            bias=False,\n            input_is_parallel=True,\n            skip_bias_add=False,\n            **row_kwargs,\n        )\n\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, x):\n        gate_up = self.gate_up_proj(x)[0]\n        gate, up = gate_up.split(self.gate_size, dim=-1)\n        return self.down_proj(self.act_fn(gate) * up)[0]\n"
  },
  {
    "path": "siirl/models/llama/megatron/layers/parallel_rmsnorm.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport numbers\n\nimport torch\nfrom apex.normalization.fused_layer_norm import fused_rms_norm_affine\nfrom megatron.core import ModelParallelConfig\nfrom torch import nn\nfrom transformers import LlamaConfig\n\nfrom siirl.utils.megatron import sequence_parallel as sp_utils\n\n\nclass ParallelLlamaRMSNorm(nn.Module):\n    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):\n        \"\"\"\n        LlamaRMSNorm is equivalent to T5LayerNorm\n        \"\"\"\n        super().__init__()\n        if isinstance(config.hidden_size, numbers.Integral):\n            normalized_shape = (config.hidden_size,)\n        self.normalized_shape = torch.Size(normalized_shape)\n        self.weight = nn.Parameter(torch.ones(self.normalized_shape))\n        self.variance_epsilon = config.rms_norm_eps\n\n        if megatron_config.sequence_parallel:\n            sp_utils.mark_parameter_as_sequence_parallel(self.weight)\n\n    def forward(self, hidden_states):\n        return fused_rms_norm_affine(\n            input=hidden_states,\n            weight=self.weight,\n            normalized_shape=self.normalized_shape,\n            eps=self.variance_epsilon,\n            memory_efficient=True,\n        )\n"
  },
  {
    "path": "siirl/models/llama/megatron/modeling_llama_megatron.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"PyTorch LLaMA model with Megatron-style acceleration.\"\"\"\n\nfrom typing import Optional, Tuple, Union\n\nimport torch\nimport torch.utils.checkpoint\nfrom megatron.core import ModelParallelConfig, mpu, tensor_parallel\nfrom torch import nn\nfrom transformers.modeling_outputs import BaseModelOutputWithPast\nfrom transformers.models.llama.configuration_llama import LlamaConfig\nfrom transformers.models.llama.modeling_llama import CausalLMOutputWithPast\n\nfrom siirl.utils.megatron import sequence_parallel as sp_utils\nfrom siirl.utils.megatron import tensor_parallel as tp_utils\nfrom siirl.utils.megatron.megatron_utils import TransformerConfig, convert_config\n\nfrom siirl.models.llama.megatron.layers import ParallelLlamaDecoderLayer, ParallelLlamaDecoderLayerRmPad, ParallelLlamaRMSNorm\n\n\"\"\"\nTODO: \n1. Add weight initialization. Here we need to be careful on TP weight init.\n2. Add sequence parallel\n3. Load checkpoint from meta LLama pretrained checkpoint\n\"\"\"\n\n\n# Copied from transformers.models.bart.modeling_bart._make_causal_mask\ndef _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device):\n    \"\"\"\n    Make causal mask used for bi-directional self-attention.\n    \"\"\"\n    bsz, tgt_len = input_ids_shape\n    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)\n    mask_cond = torch.arange(mask.size(-1), device=device)\n    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)\n    mask = mask.to(dtype)\n    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len)\n\n\n# Copied from transformers.models.bart.modeling_bart._expand_mask\ndef _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):\n    \"\"\"\n    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.\n    \"\"\"\n    bsz, src_len = mask.size()\n    tgt_len = tgt_len if tgt_len is not None else src_len\n\n    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)\n\n    inverted_mask = 1.0 - expanded_mask\n\n    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)\n\n\nclass ParallelLlamaModel(nn.Module):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`LlamaDecoderLayer`]\n\n    Args:\n        config: LlamaConfig\n    \"\"\"\n\n    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n        self.megatron_config = megatron_config\n        embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()\n        if megatron_config is not None:\n            assert embedding_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)\n        self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs)\n\n        self.layers = nn.ModuleList([ParallelLlamaDecoderLayer(config, megatron_config) for _ in range(config.num_hidden_layers)])\n        self.norm = ParallelLlamaRMSNorm(config, megatron_config)\n\n    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask\n    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds):\n        # create causal mask\n        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n        combined_attention_mask = None\n        if input_shape[-1] > 1:\n            combined_attention_mask = _make_causal_mask(\n                input_shape,\n                inputs_embeds.dtype,\n                device=inputs_embeds.device,\n            )\n\n        if attention_mask is not None:\n            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device)\n            combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask\n\n        return combined_attention_mask\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> Union[Tuple, BaseModelOutputWithPast]:\n        \"\"\"\n\n        Args:\n            input_ids: input ids. shape (batch_size, seq_length)\n            attention_mask: attention_mask. shape (batch_size, seq_length)\n            position_ids: position ids. 
shape (batch_size, seq_length)\n\n        Returns:\n\n        \"\"\"\n        batch_size, seq_length = input_ids.shape\n        inputs_embeds = self.embed_tokens(input_ids)\n        # embed positions\n\n        attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds)\n\n        hidden_states = inputs_embeds\n\n        for idx, decoder_layer in enumerate(self.layers):\n            layer_outputs = decoder_layer(\n                hidden_states,\n                attention_mask=attention_mask,\n                position_ids=position_ids,\n            )\n\n            hidden_states = layer_outputs\n\n        hidden_states = self.norm(hidden_states)\n\n        return hidden_states\n\n\nclass ParallelLlamaForCausalLM(nn.Module):\n    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.megatron_config = megatron_config\n        self.model = ParallelLlamaModel(config, megatron_config=megatron_config)\n        self.vocab_size = config.vocab_size\n\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        if megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)\n\n        self.lm_head = tensor_parallel.ColumnParallelLinear(\n            input_size=config.hidden_size,\n            output_size=config.vocab_size,\n            bias=False,\n            gather_output=False,\n            skip_bias_add=False,\n            **column_kwargs,\n        )\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        r\"\"\"\n        Args:\n            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n        Returns:\n        ```\"\"\"\n\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs = self.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n        )\n\n        hidden_states = outputs\n        logits = self.lm_head(hidden_states)[0]\n\n        logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)\n\n        logits = logits.float()\n        return CausalLMOutputWithPast(\n            loss=None,\n            logits=logits,\n            past_key_values=None,\n            hidden_states=None,\n            attentions=None,\n        )\n\n\nfrom flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa\n\n\nclass ParallelLlamaModelRmPad(nn.Module):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`LlamaDecoderLayer`]\n\n    Args:\n        config: LlamaConfig\n    \"\"\"\n\n    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n        embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()\n        self.megatron_config = megatron_config\n        if megatron_config is not None:\n            assert embedding_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)\n        self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs)\n\n        self.layers = nn.ModuleList([ParallelLlamaDecoderLayerRmPad(config, megatron_config) for _ in range(config.num_hidden_layers)])\n        self.norm = ParallelLlamaRMSNorm(config, megatron_config)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: Optional[torch.LongTensor] = None,\n        sequence_length: int = None,\n        indices: torch.Tensor = None,\n        cu_seqlens: int = None,\n        max_seqlen_in_batch: int = None,\n    ) -> Union[Tuple, BaseModelOutputWithPast]:\n        \"\"\"\n\n        Args:\n            input_ids: input ids. shape (1, totol_nnz)\n            position_ids: position ids. shape (batch_size, seq_length)\n\n        Returns:\n\n        \"\"\"\n        inputs_embeds = self.embed_tokens(input_ids)  # (1, total_nnz) -> (1, total_nnz, hidden_size)\n\n        # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size)\n        inputs_embeds = inputs_embeds.transpose(0, 1)\n        if self.megatron_config.sequence_parallel:\n            inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds)\n\n        hidden_states = inputs_embeds\n        for idx, decoder_layer in enumerate(self.layers):\n            layer_outputs = decoder_layer(\n                hidden_states,\n                position_ids=position_ids,\n                sequence_length=sequence_length,\n                indices=indices,\n                cu_seqlens=cu_seqlens,\n                max_seqlen_in_batch=max_seqlen_in_batch,\n            )\n\n            hidden_states = layer_outputs\n\n        hidden_states = self.norm(hidden_states)\n\n        return hidden_states\n\n\nclass ParallelLlamaForCausalLMRmPad(nn.Module):\n    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.megatron_config = megatron_config\n        self.model = ParallelLlamaModelRmPad(config, megatron_config=megatron_config)\n        self.vocab_size = config.vocab_size\n        self._init_head(config)\n\n    def _init_head(self, config):\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        if self.megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)\n        self.lm_head = tensor_parallel.ColumnParallelLinear(\n            input_size=config.hidden_size,\n            output_size=config.vocab_size,\n            
bias=False,\n            gather_output=False,\n            skip_bias_add=False,\n            **column_kwargs,\n        )\n\n    def _forward_head(self, hidden_states):\n        # all_gather from sequence parallel region is performed inside lm_head\n        logits = self.lm_head(hidden_states)[0]\n        logits = logits.float()  # (total_nnz_padded, 1, vocab_size // tp)\n        logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)  # (total_nnz_padded, 1, vocab_size)\n        return logits\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        r\"\"\"\n        Args:\n            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n        Returns:\n        ```\"\"\"\n        batch_size, sequence_length = input_ids.shape\n\n        # remove padding here\n        input_ids, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1), attention_mask)  # (total_nnz, 1)\n\n        # pad input_ids to multiple of tp for all tp ranks\n        # TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap\n        if self.megatron_config.sequence_parallel:\n            input_ids = sp_utils.pad_to_sequence_parallel(input_ids)\n\n        input_ids = input_ids.transpose(0, 1)  # (1, total_nnz+pad)\n\n        outputs = self.model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            sequence_length=sequence_length,\n            indices=indices,\n            cu_seqlens=cu_seqlens,\n            max_seqlen_in_batch=max_seqlen_in_batch,\n        )\n\n        hidden_states = outputs\n\n        logits = self._forward_head(hidden_states)\n\n        # remove padding from sequence parallel\n        if self.megatron_config.sequence_parallel:\n            totol_nnz = cu_seqlens[-1]\n            logits = logits[:totol_nnz]  # (total_nnz_padded)\n\n        logits = torch.squeeze(logits, dim=1)  # remove the artificial batch dimension\n        # add removed padding back\n        logits = pad_input(logits, indices, batch_size, seqlen=sequence_length)  # (batch_size, sequence_length, vocab_size)\n\n        return CausalLMOutputWithPast(\n            loss=None,\n            logits=logits,\n            past_key_values=None,\n            hidden_states=None,\n            attentions=None,\n        )\n\n\nclass ParallelLlamaForValueRmPad(ParallelLlamaForCausalLMRmPad):\n    def _init_head(self, config):\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        if self.megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)\n        self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False)\n        # lm_head is effectively the same as sequence parallel\n        sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight)\n\n   
 def _forward_head(self, hidden_states):\n        logits = self.lm_head(hidden_states)  # (total_nnz_padded // tp, 1, 1)\n        logits = logits.float()\n        if self.megatron_config.sequence_parallel:\n            logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False)\n        return logits\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        output = super().forward(input_ids, attention_mask, position_ids)\n        output.logits = torch.squeeze(output.logits, dim=-1)\n        return output\n\n\n\"\"\"\nSupport pipeline parallelism\n\"\"\"\n\n\nclass ParallelLlamaModelRmPadPP(nn.Module):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]\n    This model definition supports pipeline parallelism. To support pp and vpp,\n    - This model only contains layer in this pp stage and vpp chunk\n    - When calling get_model in Megatron, this rank will instantiate all the vpp chunks in this pp.\n    Args:\n        config: LlamaConfig\n    \"\"\"\n\n    def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, pre_process, post_process):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n        self.pre_process = pre_process\n        self.post_process = post_process\n        self.megatron_config = megatron_config\n        embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()\n        if megatron_config is not None:\n            assert embedding_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)\n        if pre_process:\n            self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs)\n        else:\n            self.embed_tokens = None\n\n        pp_rank = mpu.get_pipeline_model_parallel_rank()\n        pp_size = megatron_config.pipeline_model_parallel_size\n        self.num_layer_per_pp = config.num_hidden_layers // pp_size\n        vpp_size = megatron_config.virtual_pipeline_model_parallel_size\n        vpp_rank = mpu.get_virtual_pipeline_model_parallel_rank()\n\n        if vpp_size is not None:\n            self.layers = nn.ModuleList()\n            self.num_layer_vpp_chunk = self.num_layer_per_pp // vpp_size\n            self.num_layer_this_model = self.num_layer_vpp_chunk\n            offset = vpp_rank * (config.num_hidden_layers // vpp_size) + (pp_rank * self.num_layer_vpp_chunk)\n        else:\n            self.num_layer_this_model = self.num_layer_per_pp\n            offset = pp_rank * self.num_layer_per_pp\n\n        self.layers = nn.ModuleList()\n        for i in range(self.num_layer_this_model):\n            layer = ParallelLlamaDecoderLayerRmPad(config, megatron_config, layer_idx=offset + i)\n            self.layers.add_module(f\"{i}\", layer)\n\n        if post_process:\n            self.norm = ParallelLlamaRMSNorm(config, megatron_config)\n        else:\n            self.norm = None\n\n    def set_input_tensor(self, input_tensor):\n        \"\"\"Set input tensor to be used instead of 
forward()'s input.\n\n        When doing pipeline parallelism the input from the previous\n        stage comes from communication, not from the input, so the\n        model's forward_step_func won't have it. This function is thus\n        used by internal code to bypass the input provided by the\n        forward_step_func\"\"\"\n        self.input_tensor = input_tensor\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: Optional[torch.LongTensor] = None,\n        sequence_length: int = None,\n        indices: torch.Tensor = None,\n        cu_seqlens: int = None,\n        max_seqlen_in_batch: int = None,\n    ) -> Union[Tuple, BaseModelOutputWithPast]:\n        \"\"\"\n\n        Args:\n            input_ids: input ids. shape (1, totol_nnz)\n            position_ids: position ids. shape (batch_size, seq_length)\n\n        Returns:\n\n        \"\"\"\n        if self.pre_process:\n            inputs_embeds = self.embed_tokens(input_ids)  # (1, total_nnz) -> (1, total_nnz, hidden_size)\n\n            # vocab parallel embedding will not do sequence parallel reduce-scatter in open source megatron\n            # so need to deal with it by handle here:\n            # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size)\n            inputs_embeds = inputs_embeds.transpose(0, 1)\n            if self.megatron_config.sequence_parallel:\n                inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds)\n\n            hidden_states = inputs_embeds\n        else:\n            # self.hidden_states should be passed by Megatron\n            hidden_states = self.input_tensor\n\n        for idx, decoder_layer in enumerate(self.layers):\n            layer_outputs = decoder_layer(\n                hidden_states,\n                position_ids=position_ids,\n                sequence_length=sequence_length,\n                indices=indices,\n                cu_seqlens=cu_seqlens,\n                max_seqlen_in_batch=max_seqlen_in_batch,\n            )\n\n            hidden_states = layer_outputs\n\n        if self.post_process:\n            hidden_states = self.norm(hidden_states)\n\n        return hidden_states\n\n\nclass ParallelLlamaForCausalLMRmPadPP(nn.Module):\n    def __init__(\n        self,\n        config: LlamaConfig,\n        megatron_config: ModelParallelConfig,\n        pre_process,\n        post_process,\n        share_embeddings_and_output_weights=False,\n    ):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.megatron_config = megatron_config\n        self.model = ParallelLlamaModelRmPadPP(config, megatron_config=megatron_config, pre_process=pre_process, post_process=post_process)\n        assert share_embeddings_and_output_weights is False, \"Llama Model not supports sharing embedding and output weights\"\n        self.share_embeddings_and_output_weights = share_embeddings_and_output_weights\n        self.vocab_size = config.vocab_size\n        self.pre_process = pre_process\n        self.post_process = post_process\n        if post_process:\n            self._init_head(config)\n\n    def set_input_tensor(self, input_tensor):\n        \"\"\"Set input tensor to be used instead of forward()'s input.\n\n        When doing pipeline parallelism the input from the previous\n        stage comes from communication, not from the input, so the\n        model's forward_step_func won't have it. 
This function is thus\n        used by internal code to bypass the input provided by the\n        forward_step_func\"\"\"\n        assert len(input_tensor) == 1\n        self.model.set_input_tensor(input_tensor[0])\n\n    def _init_head(self, config):\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        if self.megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)\n        self.lm_head = tensor_parallel.ColumnParallelLinear(\n            input_size=config.hidden_size,\n            output_size=config.vocab_size,\n            bias=False,\n            gather_output=False,\n            skip_bias_add=False,\n            **column_kwargs,\n        )\n\n    def _forward_head(self, hidden_states):\n        # all_gather from sequence parallel region is performed inside lm_head\n        # logits shape before forward_head hidden_states.shape: [4, 32, 4096]\n        logits = self.lm_head(hidden_states)[0]\n        # logits shape after forward_head logits.shape: [8, 32, 8]\n        logits = logits.float()  # (total_nnz_padded, 1, vocab_size // tp)\n        return logits\n\n    def forward(\n        self,\n        # original input\n        *,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        r\"\"\"\n        Args:\n            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n        Returns:\n        ```\"\"\"\n\n        # Note that input_ids, attention_mask and position_ids should be passed to every pp layer.\n        # In the first pp, input_ids will be used, in other pp layers hidden_states will be used inside self.model\n        batch_size, sequence_length = input_ids.shape\n        # remove padding here\n        input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1), attention_mask)  # (total_nnz, 1)\n\n        # pad input_ids to multiple of tp for all tp ranks\n        # TODO: for better performance, the sp padding should be removed at each layer. 
Not sure the performance gap\n        if self.megatron_config.sequence_parallel:\n            input_ids_rmpad = sp_utils.pad_to_sequence_parallel(input_ids_rmpad)\n\n        input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz+pad)\n\n        outputs = self.model(\n            input_ids=input_ids_rmpad,\n            position_ids=position_ids,\n            sequence_length=sequence_length,\n            indices=indices,\n            cu_seqlens=cu_seqlens,\n            max_seqlen_in_batch=max_seqlen_in_batch,\n        )\n\n        if self.post_process:\n            hidden_states = outputs\n            # print(f'hidden_states.shape = {hidden_states.shape}') # torch.Size([4, 32, 4096])\n            logits = self._forward_head(hidden_states)\n            logits = torch.squeeze(logits, dim=1)  # remove the artificial batch dimension # torch.Size([8, 32, 16])\n\n            # remove padding from sequence parallel\n            if self.megatron_config.sequence_parallel:\n                totol_nnz = cu_seqlens[-1]\n                logits = logits[:totol_nnz]  # (total_nnz_padded)\n            # add removed padding back. If input is already rmpad, we let the caller pad_input\n            logits = pad_input(logits, indices, batch_size, seqlen=sequence_length)  # (batch_size, sequence_length, vocab_size)\n\n            return CausalLMOutputWithPast(\n                loss=None,\n                logits=logits,\n                past_key_values=None,\n                hidden_states=None,\n                attentions=None,\n            )\n        else:\n            return outputs\n\n\nclass ParallelLlamaForValueRmPadPP(ParallelLlamaForCausalLMRmPadPP):\n    def _init_head(self, config):\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        if self.megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)\n        self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False)\n        # lm_head is effectively the same as sequence parallel\n        sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight)\n\n    def _forward_head(self, hidden_states):\n        logits = self.lm_head(hidden_states)  # (total_nnz_padded // tp, 1, 1)\n        logits = logits.float()\n        if self.megatron_config.sequence_parallel:\n            logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False)\n        return logits\n\n    def forward(\n        self,\n        *,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)\n        if self.post_process:\n            output.logits = torch.squeeze(output.logits, dim=-1)\n            return output\n        else:\n            return output\n"
  },
  {
    "path": "siirl/models/loader.py",
    "content": "from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict\n\nfrom transformers import (\n    AutoConfig,\n    AutoProcessor,\n    AutoTokenizer,\n)\n\nfrom siirl.utils.extras.misc import skip_check_imports, try_download_model_from_other_hub\nfrom siirl.models.patcher import patch_processor, patch_tokenizer\n\nfrom loguru import logger\n\n\nif TYPE_CHECKING:\n    from transformers import PreTrainedTokenizer, ProcessorMixin\n\n    from siirl.params import ModelArguments\n\n\nclass TokenizerModule(TypedDict):\n    tokenizer: \"PreTrainedTokenizer\"\n    processor: Optional[\"ProcessorMixin\"]\n\n\ndef _get_init_kwargs(model_args: \"ModelArguments\") -> Dict[str, Any]:\n    r\"\"\"\n    Gets arguments to load config/tokenizer/model.\n    Note: including inplace operation of model_args.\n    \"\"\"\n    skip_check_imports()\n    model_args.path = try_download_model_from_other_hub(model_args)\n    return {\n        \"trust_remote_code\": model_args.trust_remote_code,\n        \"cache_dir\": model_args.cache_dir,\n        \"revision\": model_args.model_revision,\n        \"token\": model_args.hf_hub_token,\n    }\n\n\ndef set_pad_token_id(tokenizer):\n    \"\"\"Set pad_token_id to eos_token_id if it is None.\n\n    Args:\n        tokenizer (transformers.PreTrainedTokenizer): The tokenizer to be set.\n\n    \"\"\"\n    if tokenizer.pad_token_id is None:\n        tokenizer.pad_token_id = tokenizer.eos_token_id\n        logger.warning(f\"tokenizer.pad_token_id is None. Now set to {tokenizer.eos_token_id}\")\n    if tokenizer.pad_token is None:\n        tokenizer.pad_token = tokenizer.eos_token\n        logger.warning(f\"tokenizer.pad_token is None. Now set to {tokenizer.eos_token}\")\n\n\ndef load_tokenizer(\n    path: str = \"\",\n    model_args: \"ModelArguments\" = None,\n    correct_pad_token: bool = True,\n    correct_gemma2=True,\n) -> \"TokenizerModule\":\n    r\"\"\"\n    Loads pretrained tokenizer and optionally loads processor.\n    Note: including inplace operation of model_args.\n    \"\"\"\n    init_kwargs = {}\n    if model_args is not None:\n        path = model_args.path\n        init_kwargs = _get_init_kwargs(model_args)\n    config = AutoConfig.from_pretrained(path, **init_kwargs)\n    if correct_gemma2 and isinstance(path, str) and \"gemma-2-2b-it\" in path:\n        # the EOS token in gemma2 is ambiguious, which may worsen RL performance.\n        # https://huggingface.co/google/gemma-2-2b-it/commit/17a01657f5c87135bcdd0ec7abb4b2dece04408a\n        logger.warning(\"Found gemma-2-2b-it tokenizer. 
Set eos_token and eos_token_id to <end_of_turn> and 107.\", stacklevel=1)\n        init_kwargs[\"eos_token\"] = \"<end_of_turn>\"\n        init_kwargs[\"eos_token_id\"] = 107\n    try:\n        if \"InternVL\" in config.architectures[0] and \"internlm2\" in config.llm_config.model_type:\n            from siirl.models.transformers.internvl_chat.tokenization_internlm2_fast import InternLM2TokenizerFast\n\n            tokenizer = InternLM2TokenizerFast.from_pretrained(\n                path,\n                use_fast=True,\n                split_special_tokens=model_args.split_special_tokens if model_args else False,\n                padding_side=\"right\",\n                **init_kwargs,\n            )\n        else:\n            tokenizer = AutoTokenizer.from_pretrained(\n                path,\n                use_fast=model_args.use_fast_tokenizer if model_args else True,\n                split_special_tokens=model_args.split_special_tokens if model_args else False,\n                padding_side=\"right\",\n                **init_kwargs,\n            )\n    except ValueError:  # try the fast one\n        tokenizer = AutoTokenizer.from_pretrained(\n            path,\n            use_fast=True,\n            padding_side=\"right\",\n            **init_kwargs,\n        )\n    except Exception as e:\n        raise OSError(\"Failed to load tokenizer.\") from e\n\n    if model_args:\n        patch_tokenizer(tokenizer, model_args, config)\n    try:\n        processor = AutoProcessor.from_pretrained(path, **init_kwargs)\n        if model_args:\n            patch_processor(processor, config, tokenizer, model_args)\n    except Exception as e:\n        logger.debug(f\"Processor was not found: {e}.\")\n        processor = None\n\n    # Avoid load tokenizer, see:\n    # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/auto/processing_auto.py#L324\n    if processor is not None and \"Processor\" not in processor.__class__.__name__:\n        processor = None\n\n    if processor is None and \"InternVL\" in config.architectures[0]:\n        import torch\n        from siirl.models.transformers.internvl import build_transform, dynamic_preprocess\n\n        class InternVLProcessor:\n            def __init__(self, proc, config):\n                self.proc = proc\n                self.image_size = config.force_image_size or config.vision_config.image_size\n                self.patch_size = config.vision_config.patch_size\n                self.dynamic_image_size = False  # config.dynamic_image_size\n                self.min_dynamic_patch = config.min_dynamic_patch\n                self.max_dynamic_patch = config.max_dynamic_patch\n                self.use_thumbnail = config.use_thumbnail\n                self.num_image_token = int((self.image_size // self.patch_size) ** 2 * (config.downsample_ratio**2))\n\n            def process(self, images):\n                transform = self.proc(True, input_size=self.image_size)\n                pixel_values, image_flags = [], []\n                for image in images:\n                    pixel_values.append(transform(image))\n                    image_flags.append(torch.tensor([1] * 1, dtype=torch.long))\n                pixel_values = torch.stack(pixel_values)\n                image_flags = torch.stack(image_flags)\n                return {\"pixel_values\": pixel_values, \"image_flags\": image_flags}\n\n        processor = InternVLProcessor(build_transform, config)\n        logger.info(\"Build processor for model type \", config.model_type)\n\n 
   if correct_pad_token:\n        set_pad_token_id(tokenizer)\n\n    return {\"tokenizer\": tokenizer, \"processor\": processor}\n"
  },
  {
    "path": "siirl/models/mcore/__init__.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .registry import get_mcore_forward_fn, get_mcore_forward_fused_fn, get_mcore_weight_converter, hf_to_mcore_config, init_mcore_model\n\n__all__ = [\"hf_to_mcore_config\", \"init_mcore_model\", \"get_mcore_forward_fn\", \"get_mcore_forward_fused_fn\", \"get_mcore_weight_converter\"]\n"
  },
  {
    "path": "siirl/models/mcore/config_converter.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# convert huggingface config to mcore transformer config\n\n\nimport torch\nimport torch.nn.functional as F\nfrom megatron.core.transformer import MLATransformerConfig, TransformerConfig\nfrom transformers import PretrainedConfig\n\n\ndef _get_base_transformer_config(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> dict:\n    \"\"\"\n    Create a base TransformerConfig with common parameters across different model architectures.\n    TODO: (ycl) use dataclass or converter config?\n\n    Args:\n        hf_config: HuggingFace model configuration\n        dtype: Data type for the model\n        override_transformer_config_kwargs: Additional parameters to override defaults\n\n    Returns:\n        TransformerConfig with common parameters\n    \"\"\"\n    from megatron.core import parallel_state as mpu\n\n    # Common parallel state parameters\n    overlap_p2p_comm = mpu.get_virtual_pipeline_model_parallel_world_size() is not None and mpu.get_virtual_pipeline_model_parallel_world_size() > 1\n    batch_p2p_comm = False\n\n    # Base configuration with common parameters\n    base_config = {\n        # Model architecture parameters\n        \"num_layers\": hf_config.num_hidden_layers,\n        \"hidden_size\": hf_config.hidden_size,\n        \"num_attention_heads\": hf_config.num_attention_heads,\n        \"num_query_groups\": hf_config.num_key_value_heads,\n        \"ffn_hidden_size\": hf_config.intermediate_size,\n        \"attention_dropout\": hf_config.attention_dropout,\n        \"hidden_dropout\": getattr(hf_config, \"hidden_dropout\", 0.0),\n        \"kv_channels\": getattr(hf_config, \"head_dim\", None),\n        \"layernorm_epsilon\": hf_config.rms_norm_eps,\n        # Activation and normalization\n        \"activation_func\": F.silu,\n        \"normalization\": \"RMSNorm\",\n        \"gated_linear_unit\": True,\n        # Data types\n        \"pipeline_dtype\": dtype,\n        \"params_dtype\": dtype,\n        \"bf16\": dtype is torch.bfloat16,\n        # Parallel configuration\n        \"tensor_model_parallel_size\": mpu.get_tensor_model_parallel_world_size(),\n        \"pipeline_model_parallel_size\": mpu.get_pipeline_model_parallel_world_size(),\n        \"expert_model_parallel_size\": mpu.get_expert_model_parallel_world_size(),\n        \"expert_tensor_parallel_size\": mpu.get_expert_tensor_parallel_world_size(),\n        \"virtual_pipeline_model_parallel_size\": mpu.get_virtual_pipeline_model_parallel_world_size(),\n        \"context_parallel_size\": mpu.get_context_parallel_world_size(),\n        \"overlap_p2p_comm\": overlap_p2p_comm,\n        \"batch_p2p_comm\": batch_p2p_comm,\n        \"sequence_parallel\": mpu.get_tensor_model_parallel_world_size() > 1,\n        # Common settings\n  
      \"variable_seq_lengths\": True,\n        \"masked_softmax_fusion\": True,\n        \"moe_token_dispatcher_type\": \"alltoall\",\n    }\n\n    # Update with any provided overrides\n    base_config.update(override_transformer_config_kwargs)\n    print(f\"Overridden TF init config: {base_config}\")\n\n    return base_config\n\n\ndef hf_to_mcore_config_dense(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig:\n    # for LlamaForCausalLM or Qwen2ForCausalLM\n    qkv_bias = True if \"Qwen2ForCausalLM\" in hf_config.architectures else getattr(hf_config, \"attention_bias\", False)\n    qk_layernorm = True if \"Qwen3ForCausalLM\" in hf_config.architectures else False\n\n    args = _get_base_transformer_config(hf_config=hf_config, dtype=dtype, use_cpu_initialization=False, add_bias_linear=False, add_qkv_bias=qkv_bias, qk_layernorm=qk_layernorm, **override_transformer_config_kwargs)\n    return TransformerConfig(**args)\n\n\ndef hf_to_mcore_config_qwen2moe(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig:\n    args = _get_base_transformer_config(\n        hf_config=hf_config,\n        dtype=dtype,\n        use_cpu_initialization=False,\n        add_bias_linear=False,\n        layernorm_epsilon=hf_config.rms_norm_eps,\n        # MoE specific\n        moe_ffn_hidden_size=hf_config.moe_intermediate_size,\n        moe_router_bias_update_rate=0.001,\n        moe_router_topk=hf_config.num_experts_per_tok,\n        num_moe_experts=hf_config.num_experts,\n        moe_shared_expert_intermediate_size=hf_config.shared_expert_intermediate_size,\n        moe_aux_loss_coeff=hf_config.router_aux_loss_coef,\n        # moe_aux_loss_coeff=0.0,\n        moe_router_load_balancing_type=\"none\",  # turn off aux_loss as it hurts perf in RL\n        moe_shared_expert_overlap=True,\n        moe_grouped_gemm=True,\n        moe_router_score_function=\"softmax\",\n        # Other optimizations\n        persist_layer_norm=True,\n        bias_activation_fusion=True,\n        bias_dropout_fusion=True,\n        # Qwen specific\n        moe_router_pre_softmax=True,\n        add_qkv_bias=True,\n        **override_transformer_config_kwargs,\n    )\n    return TransformerConfig(**args)\n\n\ndef hf_to_mcore_config_mixtral(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig:\n    args = _get_base_transformer_config(\n        hf_config=hf_config,\n        dtype=dtype,\n        use_cpu_initialization=False,\n        add_bias_linear=False,\n        layernorm_epsilon=hf_config.rms_norm_eps,\n        # MoE specific\n        num_moe_experts=hf_config.num_local_experts,\n        moe_aux_loss_coeff=hf_config.router_aux_loss_coef,\n        moe_router_topk=hf_config.num_experts_per_tok,\n        moe_router_pre_softmax=True,\n        moe_router_load_balancing_type=\"none\",  # turn off aux_loss as it hurts perf in RL\n        moe_router_score_function=\"softmax\",\n        moe_shared_expert_intermediate_size=None,  # mixtral has no shared expert\n        moe_shared_expert_overlap=False,  # mixtral has no shared expert\n        moe_ffn_hidden_size=hf_config.intermediate_size,\n        moe_router_bias_update_rate=0.001,\n        # moe_permute_fusion=True, # need TE 2.1+\n        moe_grouped_gemm=True,\n        # Other optimizations\n        persist_layer_norm=True,\n        apply_rope_fusion=True,\n        bias_activation_fusion=True,\n        bias_dropout_fusion=True,\n     
   **override_transformer_config_kwargs,\n    )\n    return TransformerConfig(**args)\n\n\ndef hf_to_mcore_config_qwen3moe(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig:\n    args = _get_base_transformer_config(\n        hf_config=hf_config,\n        dtype=dtype,\n        use_cpu_initialization=False,\n        add_bias_linear=False,\n        layernorm_epsilon=hf_config.rms_norm_eps,\n        # MoE specific\n        moe_ffn_hidden_size=hf_config.moe_intermediate_size,\n        moe_router_bias_update_rate=0.001,\n        moe_router_topk=hf_config.num_experts_per_tok,\n        num_moe_experts=hf_config.num_experts,\n        moe_aux_loss_coeff=hf_config.router_aux_loss_coef,\n        # moe_aux_loss_coeff=0.0,\n        moe_router_load_balancing_type=\"none\",  # turn off aux_loss as it hurts perf in RL\n        moe_grouped_gemm=True,\n        moe_router_score_function=\"softmax\",\n        # Other optimizations\n        persist_layer_norm=True,\n        bias_activation_fusion=True,\n        bias_dropout_fusion=True,\n        # Qwen specific\n        moe_router_pre_softmax=False,\n        qk_layernorm=True,\n        **override_transformer_config_kwargs,\n    )\n    return TransformerConfig(**args)\n\n\ndef hf_to_mcore_config_dpskv3(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> MLATransformerConfig:\n    # DeepseekV3ForCausalLM\n    from megatron.core.transformer.enums import AttnBackend\n\n    from .patch_v012 import apply_patch\n\n    apply_patch()\n\n    mla_rope_config = {\n        \"beta_fast\": 32,\n        \"beta_slow\": 1,\n        \"factor\": 1,\n        \"mscale\": 1.0,\n        \"mscale_all_dim\": 1.0,\n        \"original_max_position_embeddings\": 4096,\n        \"type\": \"rope\",\n    }\n    if \"rope_scaling\" in hf_config and hf_config.rope_scaling is not None:\n        mla_rope_config.update(hf_config.rope_scaling)\n    moe_layer_freq = [1] * hf_config.num_hidden_layers\n    for i in range(min(hf_config.first_k_dense_replace, hf_config.num_hidden_layers)):\n        moe_layer_freq[i] = 0\n\n    # disable MTP and quantization for now\n    if \"num_nextn_predict_layers\" in hf_config:\n        assert hf_config.num_nextn_predict_layers == 0, \"MTP is not supported for now, please modify the config.json to set num_nextn_predict_layers to 0\"\n    assert \"quantization_config\" not in hf_config or not hf_config.quantization_config, \"quantization is not supported for now, please modify the config.json to remove quantization_config\"\n\n    args = _get_base_transformer_config(\n        hf_config=hf_config,\n        dtype=dtype,\n        use_cpu_initialization=False,\n        add_bias_linear=False,\n        attention_backend=AttnBackend.fused,\n        bf16=dtype is torch.bfloat16,\n        layernorm_epsilon=hf_config.rms_norm_eps,\n        ffn_hidden_size=hf_config.intermediate_size,\n        qk_layernorm=True,\n        # moe specific\n        moe_ffn_hidden_size=hf_config.moe_intermediate_size,\n        moe_token_dispatcher_type=\"alltoall\",\n        moe_router_bias_update_rate=0.001,\n        moe_router_enable_expert_bias=True,\n        moe_router_topk=hf_config.num_experts_per_tok,\n        num_moe_experts=hf_config.n_routed_experts,\n        moe_shared_expert_intermediate_size=hf_config.moe_intermediate_size * hf_config.n_shared_experts,\n        moe_aux_loss_coeff=getattr(hf_config, \"aux_loss_alpha\", 0.001),\n        moe_router_load_balancing_type=\"seq_aux_loss\",\n        
moe_shared_expert_overlap=True,\n        # moe_permute_fusion=True, # need TE 2.1+\n        moe_grouped_gemm=True,\n        moe_router_score_function=\"sigmoid\",\n        moe_router_pre_softmax=True,\n        moe_router_topk_scaling_factor=hf_config.routed_scaling_factor,\n        moe_layer_freq=moe_layer_freq,\n        # MLA\n        q_lora_rank=hf_config.q_lora_rank,\n        kv_lora_rank=hf_config.kv_lora_rank,\n        qk_head_dim=hf_config.qk_nope_head_dim,\n        qk_pos_emb_head_dim=hf_config.qk_rope_head_dim,\n        v_head_dim=hf_config.v_head_dim,\n        rotary_base=hf_config.rope_theta,\n        rotary_scaling_factor=mla_rope_config[\"factor\"],\n        rope_type=mla_rope_config[\"type\"],\n        mscale=mla_rope_config[\"mscale\"],\n        mscale_all_dim=mla_rope_config[\"mscale_all_dim\"],\n        max_position_embeddings=mla_rope_config[\"original_max_position_embeddings\"],\n        beta_fast=mla_rope_config[\"beta_fast\"],\n        beta_slow=mla_rope_config[\"beta_slow\"],\n        # mcore 0.12 moe\n        moe_router_dtype=\"fp64\",\n        disable_bf16_reduced_precision_matmul=True,\n        # other\n        # deallocate_pipeline_outputs=True,\n        # gradient_accumulation_fusion=True,\n        persist_layer_norm=True,\n        bias_activation_fusion=True,\n        bias_dropout_fusion=True,\n    )\n    if override_transformer_config_kwargs:\n        args.update(override_transformer_config_kwargs)\n    transformer_config = MLATransformerConfig(**args)\n    # MTP\n    if \"num_nextn_predict_layers\" in hf_config:\n        transformer_config.mtp_num_layers = hf_config.num_nextn_predict_layers\n        transformer_config.mtp_loss_scaling_factor = 0.1\n\n    return transformer_config\n\n\ndef hf_to_mcore_config_qwen2_5_vl(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig:\n    # Qwen2_5_VLForConditionalGeneration\n    raise NotImplementedError(\"Qwen2_5_VLForConditionalGeneration is not supported yet\")\n\n\ndef hf_to_mcore_config_llama4(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig:\n    # Llama4ForConditionalGeneration\n    raise NotImplementedError(\"Llama4ForConditionalGeneration is not supported yet\")\n"
  },
  {
    "path": "siirl/models/mcore/loader.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport time\n\nimport torch\nimport torch.distributed as dist\n\nfrom siirl.utils.extras.device import get_device_id, get_torch_device\n\nfrom .saver import _megatron_calc_global_rank\n\n\ndef _megatron_calc_layer_map(config):\n    \"\"\"Calculate the mapping of global layer_idx to local layer_idx\n    Returns:\n        layer_map (Dict: int -> tuple(int, int, int)):\n            mapping from the global layer index to\n            a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)\n    \"\"\"\n    from megatron.core import mpu\n\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n\n    layer_map = dict()\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers\n\n    for pp_rank_idx in range(pp_size):\n        for virtual_pp_rank_idx in range(virtual_pp_size):\n            layer_offset = virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model\n            for layer_idx in range(num_layers_per_model):\n                layer_map[layer_offset + layer_idx] = (\n                    pp_rank_idx,\n                    virtual_pp_rank_idx,\n                    layer_idx,\n                )\n    return layer_map\n\n\ndef load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, params_dtype, is_value_model=False):\n    \"\"\"Load merged state_dict to sharded Megatron module in training.\"\"\"\n    from megatron.core import DistributedDataParallel as LocalDDP\n    from megatron.core import mpu\n    from megatron.core.transformer.module import Float16Module\n    from torch.nn.parallel import DistributedDataParallel as torchDDP\n\n    from siirl.utils.megatron.megatron_utils import print_rank_0, unwrap_model\n\n    start_time = time.time()\n\n    def _get_gpt_model(model):\n        return model\n\n    def broadcast_params(module):\n        for param in module.parameters():\n            torch.distributed.broadcast(param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group())\n\n    dp_rank = mpu.get_data_parallel_rank()\n    pp_rank = mpu.get_pipeline_model_parallel_rank()\n    cp_rank = mpu.get_context_parallel_rank()\n    src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=0, cp_rank=cp_rank)\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n    mp_group = mpu.get_model_parallel_group()\n\n    if torch.distributed.get_rank() == src_rank:\n        assert mp_group.rank() == 0, f\"mp_rank:[{mp_group.rank}] != 0 on rank #0\"\n        assert pp_rank == 0, f\"pp_rank:[{pp_rank}] != 0 on rank #0\"\n        assert 
dp_rank == 0, f\"dp_rank:[{dp_rank}] != 0 on rank #0\"\n\n    if not isinstance(wrapped_models, (list, tuple)):\n        wrapped_models = list(wrapped_models)\n\n    assert len(wrapped_models) == virtual_pp_size\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers\n\n    models = [None] * len(wrapped_models)\n\n    for i, wrapped_model in enumerate(wrapped_models):\n        models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))\n        gpt_model_module = _get_gpt_model(models[i])\n        assert len(gpt_model_module.decoder.layers) == num_layers_per_model\n\n    def _broadcast_tensor(tensor, name) -> torch.Tensor:\n        \"\"\"broadcast tensor from rank0 across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        if torch.distributed.get_rank() == src_rank:\n            if name in state_dict:\n                weight = state_dict[name]\n                tensor_shape = weight.shape\n            else:\n                tensor_shape = None\n        else:\n            weight = None\n            tensor_shape = None\n\n        obj_list = [tensor_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        tensor_shape = obj_list[0]\n\n        if tensor_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tensor:[{name}] not in state_dict, skip load\")\n            return\n\n        if tensor is None:\n            tensor = torch.empty(\n                tensor_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        if torch.distributed.get_rank() == src_rank:\n            tensor.data.copy_(weight)\n        dist.broadcast(tensor, src=src_rank, group=mp_group)\n\n    def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        if torch.distributed.get_rank() == src_rank:\n            if name in state_dict:\n                full_weight = state_dict[name]\n\n                if mutate_func is not None:\n                    full_weight = mutate_func(full_weight)\n                tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)\n                chunk_shape = tensor_chunk[0].shape\n            else:\n                chunk_shape = None\n        else:\n            chunk_shape = None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{name}] not in state_dict, skip loading\")\n            return\n\n        if tensor is None:\n            sync_tensor = torch.empty(\n                chunk_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        else:\n            assert tensor.shape == chunk_shape, f\"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}\"\n            sync_tensor = torch.empty_like(tensor, 
device=get_device_id(), requires_grad=False)\n\n        for i in range(tp_size):\n            if torch.distributed.get_rank() == src_rank:\n                sync_tensor.data.copy_(tensor_chunk[i])\n            dist.broadcast(sync_tensor, src=src_rank, group=mp_group)\n            if (i == tp_rank) and (tensor is not None):\n                tensor.data.copy_(sync_tensor)\n\n    def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        if torch.distributed.get_rank() == src_rank:\n            if name in state_dict:\n                full_weight = state_dict[name]\n                if mutate_func is not None:\n                    full_weight = mutate_func(full_weight)\n                tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)\n                chunk_shape = tensor_chunk[0].shape\n            else:\n                chunk_shape = None\n        else:\n            chunk_shape = None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{name}] not in state_dict, skip loading\")\n            return\n\n        if tensor is None:\n            sync_tensor = torch.empty(\n                chunk_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        else:\n            assert tensor.shape == chunk_shape, f\"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}\"\n            sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False)\n\n        for i in range(tp_size):\n            if torch.distributed.get_rank() == src_rank:\n                sync_tensor.data.copy_(tensor_chunk[i])\n            dist.broadcast(sync_tensor, src=src_rank, group=mp_group)\n            if (i == tp_rank) and (tensor is not None):\n                tensor.data.copy_(sync_tensor)\n\n    def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        if torch.distributed.get_rank() == src_rank:\n            gate_weight = state_dict[gate_name]\n            up_weight = state_dict[up_name]\n            new_gate_up_weight = torch.empty(config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id())\n            for i in range(tp_size):\n                intermediate_size_tp = config.intermediate_size // tp_size\n                gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]\n                up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]\n                new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_(torch.cat([gate_weight_tp, up_weight_tp], dim=0))\n\n            tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0)\n            
chunk_shape = tensor_chunk[0].shape\n        else:\n            chunk_shape = None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading\")\n            return\n\n        if tensor is None:\n            sync_tensor = torch.empty(\n                chunk_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        else:\n            assert tensor.shape == chunk_shape, f\"rank #{torch.distributed.get_rank() == src_rank:} tensor {gate_name, up_name} shape {tensor.shape} != {chunk_shape}\"\n            sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False)\n\n        for i in range(tp_size):\n            if torch.distributed.get_rank() == src_rank:\n                sync_tensor.data.copy_(tensor_chunk[i])\n            dist.broadcast(sync_tensor, src=src_rank, group=mp_group)\n            if (i == tp_rank) and (tensor is not None):\n                tensor.data.copy_(sync_tensor)\n\n    def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        if torch.distributed.get_rank() == src_rank:\n            assert q_name in state_dict and k_name in state_dict and v_name in state_dict\n            full_weight_q = state_dict[q_name]\n            full_weight_k = state_dict[k_name]\n            full_weight_v = state_dict[v_name]\n\n            hidden_size_per_head = getattr(config, \"head_dim\", config.hidden_size // config.num_attention_heads)\n\n            if config.num_key_value_heads >= tp_size:\n                q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size\n                kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size\n                total_size = q_size_tp + 2 * kv_size_tp\n                sizes = [total_size * tp_size]\n                if not bias:\n                    sizes.append(config.hidden_size)\n                new_weight_qkv = torch.empty(*sizes, dtype=params_dtype, device=get_device_id())\n                for i in range(tp_size):\n                    q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]\n                    k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp]\n                    v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp]\n                    num_query_groups_per_partition = models[0].config.num_query_groups // tp_size\n                    new_weight_qkv_this_tp = new_weight_qkv[i * total_size : (i + 1) * total_size]\n                    q_part_per_head = torch.chunk(q_part, num_query_groups_per_partition, dim=0)\n                    k_part_per_head = torch.chunk(k_part, num_query_groups_per_partition, dim=0)\n                    v_part_per_head = torch.chunk(v_part, num_query_groups_per_partition, dim=0)\n                    total_size_per_head = total_size // num_query_groups_per_partition\n                    for j in range(num_query_groups_per_partition):\n                        
new_weight_qkv_this_tp[j * total_size_per_head : (j + 1) * total_size_per_head].copy_(torch.cat([q_part_per_head[j], k_part_per_head[j], v_part_per_head[j]], dim=0))\n\n            else:\n                q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size\n                kv_size_tp = hidden_size_per_head\n                total_size = q_size_tp + 2 * kv_size_tp\n                sizes = [total_size * tp_size]\n                if not bias:\n                    sizes.append(config.hidden_size)\n                new_weight_qkv = torch.empty(*sizes, dtype=params_dtype, device=get_device_id())\n                for i in range(tp_size):\n                    q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]\n                    start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head\n                    end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head\n                    k_part = full_weight_k[start_idx:end_idx]\n                    v_part = full_weight_v[start_idx:end_idx]\n                    new_weight_qkv_this_tp = new_weight_qkv[i * total_size : (i + 1) * total_size]\n                    q_part_per_head = torch.chunk(q_part, config.num_attention_heads, dim=0)\n                    k_part_per_head = torch.chunk(k_part, config.num_attention_heads, dim=0)\n                    v_part_per_head = torch.chunk(v_part, config.num_attention_heads, dim=0)\n                    total_size_per_head = total_size // config.num_attention_heads\n                    for j in range(config.num_attention_heads):\n                        new_weight_qkv_this_tp[j * total_size_per_head : (j + 1) * total_size_per_head].copy_(torch.cat([q_part_per_head[j], k_part_per_head[j], v_part_per_head[j]], dim=0))\n\n            tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0)\n            chunk_shape = tensor_chunk[0].shape\n        else:\n            chunk_shape = None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{q_name, k_name, v_name}] not in state_dict, skip loading\")\n            return\n\n        if tensor is None:\n            sync_tensor = torch.empty(\n                chunk_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        else:\n            assert tensor.shape == chunk_shape, f\"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}\"\n            sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False)\n\n        for i in range(tp_size):\n            if torch.distributed.get_rank() == src_rank:\n                sync_tensor.data.copy_(tensor_chunk[i])\n            dist.broadcast(sync_tensor, src=src_rank, group=mp_group)\n            if (i == tp_rank) and (tensor is not None):\n                tensor.data.copy_(sync_tensor)\n\n    if dp_rank == 0:\n        # Embeddings\n        # -------------------\n        print_rank_0(\"loading embeddings...\")\n        gpt_model_module = _get_gpt_model(models[0])\n        embed_tokens_weight = None\n        if pp_rank == 0:\n            embed_tokens_weight = gpt_model_module.embedding.word_embeddings.weight\n        _broadcast_tp_shard_tensor_vocab(embed_tokens_weight, 
\"model.embed_tokens.weight\")\n\n        # Transformer layers\n        # -------------------\n        layer_map = _megatron_calc_layer_map(config)\n\n        for layer in range(config.num_hidden_layers):\n            print_rank_0(f\"loading layer #{layer}...\")\n            layer_name = f\"model.layers.{layer}\"\n            dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer]\n\n            gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank])\n            sync_layer = gpt_model_module.decoder.layers[dst_layer_idx]\n\n            _broadcast_tensor(\n                sync_layer.self_attention.linear_qkv.layer_norm_weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.input_layernorm.weight\",\n            )\n\n            if f\"{layer_name}.self_attn.q_norm.weight\" in state_dict:\n                _broadcast_tensor(\n                    sync_layer.self_attention.q_layernorm.weight if dst_pp_rank == pp_rank else None,\n                    f\"{layer_name}.self_attn.q_norm.weight\",\n                )\n                _broadcast_tensor(\n                    sync_layer.self_attention.k_layernorm.weight if dst_pp_rank == pp_rank else None,\n                    f\"{layer_name}.self_attn.k_norm.weight\",\n                )\n\n            _broadcast_tp_shard_tensor_qkv(\n                sync_layer.self_attention.linear_qkv.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.self_attn.q_proj.weight\",\n                f\"{layer_name}.self_attn.k_proj.weight\",\n                f\"{layer_name}.self_attn.v_proj.weight\",\n            )\n            if f\"{layer_name}.self_attn.q_proj.bias\" in state_dict:\n                _broadcast_tp_shard_tensor_qkv(\n                    sync_layer.self_attention.linear_qkv.bias if dst_pp_rank == pp_rank else None,\n                    f\"{layer_name}.self_attn.q_proj.bias\",\n                    f\"{layer_name}.self_attn.k_proj.bias\",\n                    f\"{layer_name}.self_attn.v_proj.bias\",\n                    bias=True,\n                )\n\n            _broadcast_tp_shard_tensor(\n                sync_layer.self_attention.linear_proj.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.self_attn.o_proj.weight\",\n                chunk_dim=1,\n            )\n            _broadcast_tensor(\n                sync_layer.mlp.linear_fc1.layer_norm_weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.post_attention_layernorm.weight\",\n            )\n\n            _broadcast_tp_shard_tensor_gate_up(\n                sync_layer.mlp.linear_fc1.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.mlp.gate_proj.weight\",\n                f\"{layer_name}.mlp.up_proj.weight\",\n            )\n\n            _broadcast_tp_shard_tensor(\n                sync_layer.mlp.linear_fc2.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.mlp.down_proj.weight\",\n                chunk_dim=1,\n            )\n        # Final Layernorm\n        # -------------------\n        print_rank_0(\"loading final layernorm...\")\n        gpt_model_module = _get_gpt_model(models[-1])\n        _broadcast_tensor(\n            getattr(gpt_model_module.decoder.final_layernorm, \"weight\", None),\n            \"model.norm.weight\",\n        )\n\n        print_rank_0(\"loading lm_head...\")\n        lm_head_weight = None\n        if pp_rank + 1 == pp_size:\n            lm_head_weight = gpt_model_module.output_layer.weight\n\n        
if is_value_model:\n            if \"lm_head.weight\" in state_dict and state_dict[\"lm_head.weight\"].shape[0] == 1:\n                _broadcast_tensor(lm_head_weight, \"lm_head.weight\")\n            elif \"reward_head.weight\" in state_dict and state_dict[\"reward_head.weight\"].shape[0] == 1:\n                _broadcast_tensor(lm_head_weight, \"reward_head.weight\")\n                print_rank_0(\"load lm_head from value_head weight\")\n            else:\n                _broadcast_tensor(None, \"lm_head.weight\")\n                print_rank_0(\"failed to match lm_head in value_model\")\n        else:\n            _broadcast_tp_shard_tensor(lm_head_weight, \"lm_head.weight\")\n    dist.barrier()\n    # Broadcast weights inside data parallel groups\n    for wrapped_model in wrapped_models:\n        broadcast_params(wrapped_model)\n    get_torch_device().empty_cache()\n    print_rank_0(f\"loading megatron ckpt done, time elapsed {time.time() - start_time}s\")\n"
  },
  {
    "path": "siirl/models/mcore/mbridge.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\ntry:\n    from mbridge import AutoBridge\n    from mbridge.utils.post_creation_callbacks import freeze_moe_router, make_value_model\nexcept ImportError:\n    print(\"mbridge package not found. Please install mbridge with `pip install verl[mcore]` or `pip install mbridge`\")\n    raise\n\n__all__ = [\"AutoBridge\", \"make_value_model\", \"freeze_moe_router\"]\n"
  },
  {
    "path": "siirl/models/mcore/model_forward.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom siirl.utils.megatron.megatron_utils import unwrap_model\n\nfrom .util import postprocess_packed_seqs, preprocess_packed_seqs, recover_left_padding, remove_left_padding\n\n\ndef gptmodel_forward(\n    model,\n    input_ids,\n    attention_mask,\n    position_ids,\n    sequence_parallel,\n    value_model=False,\n    pack_seqs=True,\n    logits_processor=None,\n    logits_processor_args: dict = None,\n    **kwargs,\n):  \n    \"\"\"Default forward pass for GPT models with optional sequence packing.\"\"\"\n    pre_process = unwrap_model(model).pre_process\n    post_process = unwrap_model(model).post_process\n    if pack_seqs:\n        batch_size, seq_len = attention_mask.shape[:2]\n        input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=pre_process)\n        input_ids_rmpad = input_ids_rmpad.contiguous()\n        output_orig = model(\n            input_ids=input_ids_rmpad,\n            attention_mask=None,\n            position_ids=position_ids,\n            packed_seq_params=packed_seq_params,\n        )\n        if post_process and logits_processor is not None:\n            args = {k: preprocess_packed_seqs(v, attention_mask, pre_process=True)[0] for k, v in logits_processor_args.items()}\n            output_dict = logits_processor(output_orig, **args)\n            output = {k: postprocess_packed_seqs(v, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process) for k, v in output_dict.items()}\n        else:\n            output = postprocess_packed_seqs(output_orig, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process)\n    else:\n        assert logits_processor is None, \"logits_processor is not supported for non-packed sequence\"\n        batch_size, sequence_length = attention_mask.shape\n        new_input_ids, new_attention_mask, new_position_ids = remove_left_padding(input_ids, attention_mask, position_ids, sequence_parallel, pre_process=pre_process)\n        output = model(input_ids=new_input_ids, attention_mask=new_attention_mask, position_ids=new_position_ids)\n        output = recover_left_padding(output, new_attention_mask, attention_mask, sequence_length, post_process=post_process)\n    if value_model and post_process:\n        output = output[..., 0]\n    return output\n"
  },
  {
    "path": "siirl/models/mcore/model_forward_fused.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom collections import OrderedDict\nfrom typing import Optional\n\nimport torch\nfrom megatron.core import parallel_state\nfrom megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk\nfrom megatron.core.inference.contexts import BaseInferenceContext\nfrom megatron.core.models.gpt.gpt_model import GPTModel\nfrom megatron.core.packed_seq_params import PackedSeqParams\nfrom megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region\nfrom torch import Tensor\n\nfrom .util import preprocess_packed_seqs\nfrom siirl.utils.megatron.megatron_utils import unwrap_model\nfrom siirl.utils.model_utils.model import CausalLMOutputForPPO\n\nfrom .util import postprocess_packed_seqs_for_dict_output\n\n\ndef patch_fused_forward(model: torch.nn.Module):\n    model = unwrap_model(model)\n    if isinstance(model, GPTModel):\n        model = model\n    else:\n        raise ValueError(\"Model is not a GPTModel\")\n    model.forward_backup = model.forward\n    model.forward = _fused_GPTModel_forward.__get__(model, model.__class__)\n    return\n\n\ndef unpatch_fused_forward(model: torch.nn.Module):\n    model = unwrap_model(model)\n    if isinstance(model, GPTModel):\n        model = model\n    else:\n        raise ValueError(\"Model is not a GPTModel\")\n    model.forward = model.forward_backup\n    return\n\n\ndef fused_forward_gptmodel(\n    model: GPTModel,\n    input_ids: Tensor,\n    position_ids: Tensor,\n    attention_mask: Tensor,\n    labels: Tensor,\n    labels_mask: Tensor,\n    temperature: float = 1.0,\n    **kwargs,\n):\n    pre_process: bool = unwrap_model(model).pre_process\n    post_process: bool = unwrap_model(model).post_process\n\n    batch_size, seq_len = attention_mask.shape[:2]\n    input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=pre_process)\n    input_ids_rmpad = input_ids_rmpad.contiguous()\n    labels_rmpad, _ = preprocess_packed_seqs(labels, attention_mask, pre_process=True)\n    labels_mask_rmpad, _ = preprocess_packed_seqs(labels_mask, attention_mask, pre_process=True)\n    labels_rmpad = labels_rmpad.contiguous()\n    labels_mask_rmpad = labels_mask_rmpad.contiguous()\n\n    output_orig: CausalLMOutputForPPO = model(\n        input_ids=input_ids_rmpad,\n        attention_mask=None,\n        position_ids=position_ids,\n        labels=labels_rmpad,\n        packed_seq_params=packed_seq_params,\n        temperature=temperature,\n    )\n\n    if post_process:\n        # output_orig is in type of CausalLMOutputForPPO\n        output = postprocess_packed_seqs_for_dict_output(\n            labels_mask_rmpad,\n            output_orig,\n            packed_seq_params,\n            attention_mask,\n            batch_size,\n            
seq_len,\n            post_process=post_process,\n        )\n    else:\n        output = output_orig\n    return output\n\n\ndef _fused_GPTModel_forward(\n    self,\n    input_ids: Tensor,\n    position_ids: Tensor,\n    attention_mask: Tensor,\n    decoder_input: Tensor = None,\n    labels: Tensor = None,\n    inference_context: BaseInferenceContext = None,\n    packed_seq_params: PackedSeqParams = None,\n    extra_block_kwargs: dict = None,\n    runtime_gather_output: Optional[bool] = None,\n    *,\n    inference_params: Optional[BaseInferenceContext] = None,\n    loss_mask: Optional[Tensor] = None,\n    temperature: float = 1.0,\n) -> CausalLMOutputForPPO:\n    \"\"\"\n    Forward pass for GPT models with fused kernel support.\n\n    Patch https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/models/gpt/gpt_model.py\n    \"\"\"\n\n    # If decoder_input is provided (not None), then input_ids and position_ids are ignored.\n    # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.\n\n    # Decoder embedding.\n    if decoder_input is not None:\n        pass\n    elif self.pre_process:\n        decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)\n    else:\n        # intermediate stage of pipeline\n        # decoder will get hidden_states from encoder.input_tensor\n        decoder_input = None\n\n    # Rotary positional embeddings (embedding is None for PP intermediate devices)\n    rotary_pos_emb = None\n    rotary_pos_cos = None\n    rotary_pos_sin = None\n    if self.position_embedding_type == \"rope\" and not self.config.multi_latent_attention:\n        if not self.training and self.config.flash_decode and inference_context:\n            assert inference_context.is_static_batching(), \"GPTModel currently only supports static inference batching.\"\n            # Flash decoding uses precomputed cos and sin for RoPE\n            rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault(\n                inference_context.max_sequence_length,\n                self.rotary_pos_emb.get_cos_sin(inference_context.max_sequence_length),\n            )\n        else:\n            rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(\n                inference_context, self.decoder, decoder_input, self.config, packed_seq_params\n            )\n            rotary_pos_emb = self.rotary_pos_emb(\n                rotary_seq_len,\n                packed_seq=packed_seq_params is not None and packed_seq_params.qkv_format == \"thd\",\n            )\n    elif self.position_embedding_type == \"mrope\" and not self.config.multi_latent_attention:\n        if self.training or not self.config.flash_decode:\n            rotary_pos_emb = self.rotary_pos_emb(position_ids, self.mrope_section)\n        else:\n            # Flash decoding uses precomputed cos and sin for RoPE\n            raise NotImplementedError(\n                \"Flash decoding uses precomputed cos and sin for RoPE, not implmented in MultimodalRotaryEmbedding yet.\"\n            )\n\n    if (\n        (self.config.enable_cuda_graph or self.config.flash_decode)\n        and rotary_pos_cos is not None\n        and inference_context\n        and inference_context.is_static_batching()\n        and not self.training\n    ):\n        sequence_len_offset = torch.tensor(\n            [inference_context.sequence_len_offset] * inference_context.current_batch_size,\n            dtype=torch.int32,\n            device=rotary_pos_cos.device,  # Co-locate this with the rotary 
tensors\n        )\n    else:\n        sequence_len_offset = None\n\n    # Wrap decoder_input to allow the decoder (TransformerBlock) to delete the\n    # reference held by this caller function, enabling early garbage collection for\n    # skip inference\n\n    # Run decoder.\n    hidden_states = self.decoder(\n        hidden_states=decoder_input,\n        attention_mask=attention_mask,\n        inference_context=inference_context,\n        rotary_pos_emb=rotary_pos_emb,\n        rotary_pos_cos=rotary_pos_cos,\n        rotary_pos_sin=rotary_pos_sin,\n        packed_seq_params=packed_seq_params,\n        sequence_len_offset=sequence_len_offset,\n        **(extra_block_kwargs or {}),\n    )\n\n    # Process inference output.\n    if inference_context and not inference_context.is_static_batching():\n        hidden_states = inference_context.last_token_logits(hidden_states.squeeze(1).unsqueeze(0)).unsqueeze(1)\n\n    # logits and loss\n    output_weight = None\n    if self.share_embeddings_and_output_weights:\n        output_weight = self.shared_embedding_or_output_weight()\n\n    if self.mtp_process:\n        hidden_states = self.mtp(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            labels=labels,\n            loss_mask=loss_mask,\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            inference_params=inference_params,\n            rotary_pos_emb=rotary_pos_emb,\n            rotary_pos_cos=rotary_pos_cos,\n            rotary_pos_sin=rotary_pos_sin,\n            packed_seq_params=packed_seq_params,\n            sequence_len_offset=sequence_len_offset,\n            embedding=self.embedding,\n            output_layer=self.output_layer,\n            output_weight=output_weight,\n            runtime_gather_output=runtime_gather_output,\n            compute_language_model_loss=self.compute_language_model_loss,\n            **(extra_block_kwargs or {}),\n        )\n\n    if not self.post_process:\n        return hidden_states\n\n    output = CausalLMOutputForPPO(\n        loss=None,\n        logits=None,\n        past_key_values=None,\n        hidden_states=hidden_states,\n        attentions=None,\n    )\n\n    if self.config.sequence_parallel:\n        hidden_states = gather_from_sequence_parallel_region(hidden_states)\n    \n    from siirl.utils.kernel.linear_cross_entropy import linear_cross_entropy\n    logprobs, entropy = linear_cross_entropy(\n        hidden_states,\n        self.output_layer.weight,\n        labels,\n        temperature,\n        \"none\",\n        parallel_state.get_tensor_model_parallel_group(),\n    )\n\n    if has_config_logger_enabled(self.config):\n        payload = OrderedDict(\n            {\n                \"input_ids\": input_ids,\n                \"position_ids\": position_ids,\n                \"attention_mask\": attention_mask,\n                \"decoder_input\": decoder_input,\n                \"logprobs\": logprobs,\n                \"entropy\": entropy,\n            }\n        )\n        log_config_to_disk(self.config, payload, prefix=\"input_and_logits\")\n\n    output.entropy = entropy\n    output.log_probs = logprobs\n\n    return output\n"
  },
  {
    "path": "siirl/models/mcore/model_initializer.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# use mcore transformer config to initialize the model\nfrom abc import ABC, abstractmethod\n\nfrom megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec, get_gpt_mtp_block_spec\nfrom megatron.core.models.gpt.gpt_model import GPTModel\n\nfrom .config_converter import PretrainedConfig, TransformerConfig\n\n\nclass BaseModelInitializer(ABC):\n    \"\"\"Base class for model initializers.\"\"\"\n\n    def __init__(self, tfconfig: TransformerConfig, hf_config: PretrainedConfig):\n        self.tfconfig = tfconfig\n        self.hf_config = hf_config\n\n    @abstractmethod\n    def get_transformer_layer_spec(self):\n        \"\"\"Get the transformer layer specification.\n        https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/models/gpt/gpt_layer_specs.py\"\"\"\n        pass\n\n    def get_rope_scaling_args(self) -> dict:\n        \"\"\"Get rope scaling args.\"\"\"\n        rope_scaling_args = {}\n        if \"rope_scaling\" in self.hf_config:\n            if self.hf_config.rope_scaling is not None:\n                # assert self.hf_config.rope_scaling[\"type\"] == \"linear\", \"only linear scaling is supported for now\"\n                rope_scaling_args[\"seq_len_interpolation_factor\"] = self.hf_config.rope_scaling[\"factor\"]\n        return rope_scaling_args\n\n    def initialize(\n        self,\n        pre_process: bool = True,\n        post_process: bool = True,\n        share_embeddings_and_output_weights: bool = False,\n        value: bool = False,\n        **extra_kwargs,\n    ) -> GPTModel:\n        \"\"\"Initialize a GPT model with the given configuration.\n        https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/models/gpt/gpt_model.py\n\n        Args:\n            pre_process (bool): include embedding layer.\n            post_process (bool): including an output layer.\n            share_embeddings_and_output_weights (bool): input embeddings and output logit weights are shared.\n            value (bool): add an extra linear layer for classification or regression.\n\n        Returns:\n            GPTModel: An initialized GPT model instance\n        \"\"\"\n        transformer_layer_spec = self.get_transformer_layer_spec()\n        rope_scaling_args = self.get_rope_scaling_args()\n        mtp_block_spec = extra_kwargs.get(\"mtp_block_spec\", None)\n        model = GPTModel(\n            config=self.tfconfig,\n            transformer_layer_spec=transformer_layer_spec,\n            vocab_size=self.hf_config.vocab_size,\n            max_sequence_length=self.hf_config.max_position_embeddings,\n            pre_process=pre_process,\n            post_process=post_process,\n            share_embeddings_and_output_weights=share_embeddings_and_output_weights,\n            
position_embedding_type=\"rope\",\n            rotary_base=self.hf_config.rope_theta,\n            **rope_scaling_args,\n            mtp_block_spec=mtp_block_spec,\n        )\n\n        if post_process and value:\n            from siirl.models.llama.megatron.layers.parallel_linear import LinearForLastLayer\n\n            model.output_layer = LinearForLastLayer(input_size=self.tfconfig.hidden_size, output_size=1, config=self.tfconfig)\n\n        return model\n\n\nclass DenseModel(BaseModelInitializer):\n    \"\"\"Initializer for dense models like Llama and Qwen2.\"\"\"\n\n    def get_transformer_layer_spec(self):\n        assert self.tfconfig.normalization == \"RMSNorm\", \"only RMSNorm is supported for now\"\n        return get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True)\n\n\nclass Qwen2MoEModel(BaseModelInitializer):\n    \"\"\"Initializer for Qwen2 MoE models.\"\"\"\n\n    def get_transformer_layer_spec(self):\n        assert self.tfconfig.normalization == \"RMSNorm\", \"only RMSNorm is supported for now\"\n        transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True)\n\n        # Patch layer spec for shared experts\n        for i in range(len(transformer_layer_spec.layer_specs)):\n            transformer_layer_spec.layer_specs[i].submodules.mlp.submodules.shared_experts.params[\"gate\"] = True\n\n        return transformer_layer_spec\n\n    def initialize(self, **kwargs):\n        # Qwen default freeze_moe_router: true\n        model = super().initialize(**kwargs)\n        freeze_moe_router = kwargs.get(\"freeze_moe_router\", True)\n        if freeze_moe_router:\n            for layer in model.decoder.layers:\n                layer.mlp.router.weight.requires_grad = False\n        return model\n\n\nclass MixtralModel(BaseModelInitializer):\n    \"\"\"Initializer for Mixtral models.\"\"\"\n\n    def get_transformer_layer_spec(self):\n        assert self.tfconfig.normalization == \"RMSNorm\", \"only RMSNorm is supported for now\"\n        transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True)\n        return transformer_layer_spec\n\n    def initialize(self, **kwargs):\n        model = super().initialize(**kwargs)\n        freeze_moe_router = kwargs.get(\"freeze_moe_router\", False)\n        if freeze_moe_router:\n            for layer in model.decoder.layers:\n                layer.mlp.router.weight.requires_grad = False\n        return model\n\n\nclass Qwen3MoEModel(BaseModelInitializer):\n    \"\"\"Initializer for Qwen3 MoE models.\"\"\"\n\n    def get_transformer_layer_spec(self):\n        assert self.tfconfig.normalization == \"RMSNorm\", \"only RMSNorm is supported for now\"\n        transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True)\n        return transformer_layer_spec\n\n    def initialize(self, **kwargs):\n        # Qwen default freeze_moe_router: true\n        model = super().initialize(**kwargs)\n        freeze_moe_router = kwargs.get(\"freeze_moe_router\", True)\n        if freeze_moe_router:\n            for layer in model.decoder.layers:\n                layer.mlp.router.weight.requires_grad = False\n        return model\n\n\nclass DeepseekV3Model(BaseModelInitializer):\n    \"\"\"Initializer for DeepseekV3 models.\"\"\"\n\n    def get_transformer_layer_spec(self):\n        transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True)\n        return transformer_layer_spec\n\n    def 
get_rope_scaling_args(self) -> dict:\n        \"\"\"Get rope scaling args.\"\"\"\n        rope_scaling_args = {}\n        return rope_scaling_args\n\n    def initialize(\n        self,\n        **kwargs,\n    ):\n        freeze_moe_router = kwargs.get(\"freeze_moe_router\", True)\n        if freeze_moe_router:\n            self.tfconfig.moe_router_load_balancing_type = \"none\"\n        # MTP\n        if self.tfconfig.mtp_num_layers is not None:\n            transformer_layer_spec = self.get_transformer_layer_spec()\n            mtp_block_spec = get_gpt_mtp_block_spec(self.tfconfig, transformer_layer_spec, use_transformer_engine=True)\n            kwargs[\"mtp_block_spec\"] = mtp_block_spec\n\n        model = super().initialize(**kwargs)\n        if freeze_moe_router:\n            for layer in model.decoder.layers:\n                if hasattr(layer.mlp, \"router\"):\n                    layer.mlp.router.weight.requires_grad = False\n        return model\n\n\nclass Qwen25VLModel(BaseModelInitializer):\n    \"\"\"Initializer for Qwen2.5 VL models.\"\"\"\n\n    def get_transformer_layer_spec(self):\n        raise NotImplementedError(\"VLM is not supported yet\")\n"
  },
  {
    "path": "siirl/models/mcore/patch_v012.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# there is some bug in mcore 0.12, so we need to patch it\n# 1. `get_query_key_value_tensors` in `multi_latent_attention.py` works wrong when packed_seq_params is not None\n\n\ndef apply_patch():\n    import torch\n    from megatron.core.transformer.multi_latent_attention import MLASelfAttention, apply_rotary_pos_emb, deprecate_inference_params, gather_from_sequence_parallel_region, gather_from_tensor_model_parallel_region, scatter_to_sequence_parallel_region\n    from megatron.core import parallel_state, tensor_parallel\n\n    def patch_get_query_key_value_tensors(\n        self,\n        hidden_states,\n        key_value_states=None,\n        position_ids=None,\n        packed_seq_params=None,\n        inference_context=None,\n        *,\n        inference_params=None,\n    ):\n        \"\"\"\n        Derives `query`, `key` and `value` tensors from `hidden_states`.\n        \"\"\"\n        # s = sequence length, b = batch size, h = hidden size, n = num attention heads\n        # Attention heads [s, b, n*h]\n        assert hidden_states.ndim == 3, f\"hidden_states should be 3D, [s, b, n*h], got {hidden_states.ndim}D\"\n\n        inference_context = deprecate_inference_params(inference_context, inference_params)\n\n        # =========================================\n        # Prepare RoPE and seqlen related params\n        # =========================================\n        rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(inference_context, None, hidden_states, self.config, packed_seq_params)\n\n        # rotary_pos_emb:[s, b, 1, 64]\n        mscale = 1.0\n        if self.config.rope_type == \"rope\":\n            packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == \"thd\"\n            rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len, packed_seq=packed_seq)\n        else:\n            rotary_pos_emb, mscale = self.rotary_pos_emb(rotary_seq_len)\n\n        # =========================================\n        # QKV down projection and layernorm\n        # =========================================\n        if self.config.q_lora_rank is not None:\n            # if linear_q_down_proj is ColumnParallelLinear:\n            #     q_compressed: [s, b, q_lora_rank / TP]\n            # elif linear_q_down_proj is Linear:\n            #     q_compressed: [s / TP, b, q_lora_rank]\n            q_compressed, _ = self.linear_q_down_proj(hidden_states)\n\n            # When output is sharded (ColumnParallelLinear), two things are needed to be\n            # identical to a normal Linear.\n            #   1. Manually gather output to restore output dim q_lora_rank;\n            #   2. 
Scatter sequence back to s / TP if sequence-parallel since it was\n            #      gathered by ColumnParallelLinear.\n            if q_compressed.size(-1) != self.config.q_lora_rank:\n                q_compressed = gather_from_tensor_model_parallel_region(q_compressed)\n                if self.config.sequence_parallel:\n                    q_compressed = scatter_to_sequence_parallel_region(q_compressed)\n\n            q_compressed = self.q_layernorm(q_compressed)\n        else:\n            q_compressed = hidden_states\n\n        # if linear_kv_down_proj is ColumnParallelLinear:\n        #     kv_combined: [s, b, (kv_lora_rank + qk_pos_emb_head_dim) / TP]\n        # elif linear_kv_down_proj is Linear:\n        #     kv_combined: [s / TP, b, (kv_lora_rank + qk_pos_emb_head_dim)]\n        kv_combined, _ = self.linear_kv_down_proj(hidden_states)\n        if kv_combined.size(-1) != self.config.kv_lora_rank + self.config.qk_pos_emb_head_dim:\n            # kv_combined: [s, b, (kv_lora_rank + qk_pos_emb_head_dim)]\n            kv_combined = gather_from_tensor_model_parallel_region(kv_combined)\n            # kv_compressed:[s, b, kv_lora_rank], k_pos_emb: [s, b, qk_pos_emb_head_dim]\n            kv_compressed, k_pos_emb = torch.split(kv_combined, [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim], dim=-1)\n            if self.config.sequence_parallel:\n                # kv_compressed:[s / TP, b, kv_lora_rank]\n                kv_compressed = scatter_to_sequence_parallel_region(kv_compressed)\n        else:\n            # kv_compressed:[s / TP, b, kv_lora_rank], k_pos_emb: [s / TP, b, qk_pos_emb_head_dim]\n            kv_compressed, k_pos_emb = torch.split(kv_combined, [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim], dim=-1)\n            if parallel_state.get_tensor_model_parallel_world_size() > 1:\n                # k_pos_emb: [s, b, qk_pos_emb_head_dim]\n                k_pos_emb = gather_from_sequence_parallel_region(k_pos_emb)\n\n        kv_compressed = self.kv_layernorm(kv_compressed)\n\n        # =========================================\n        # QKV up projection and RoPE apply\n        # =========================================\n        def qkv_up_proj_and_rope_apply(q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb):\n            if self.config.q_lora_rank is not None:\n                q, _ = self.linear_q_up_proj(q_compressed)\n            else:\n                # hidden_states:[s, b, 2048], q: [s, b, n * 192]\n                q, _ = self.linear_q_proj(q_compressed)\n\n            q_len, bsz, _ = q.size()\n\n            # q: [s, b, n, 192]\n            q = q.view(q_len, bsz, self.num_attention_heads_per_partition, self.q_head_dim)\n\n            # kv: [s, b, 2048]\n            kv, _ = self.linear_kv_up_proj(kv_compressed)\n\n            # kv: [s, b, n, 256]\n            kv = kv.view(\n                q_len,\n                bsz,\n                self.num_attention_heads_per_partition,\n                self.config.qk_head_dim + self.config.v_head_dim,\n            )\n\n            if inference_context is not None:\n                # add offset to the sequence start for inference\n                sequence_start = inference_context.sequence_len_offset\n                sequence_end = sequence_start + q_len\n                rotary_pos_emb = rotary_pos_emb[sequence_start:sequence_end]\n            else:\n                # Shorten rotary_pos_emb to the sequence length when inference_params\n                # is not provided. 
This makes sure we can run forward directly with\n                # any sequence length. During training, the sequence length is always\n                # the full rotary_pos_emb length.\n                rotary_pos_emb = rotary_pos_emb[0:q_len]\n\n            # [s, b, 64] -> [s, b, 1, 64]\n            k_pos_emb = torch.unsqueeze(k_pos_emb, 2)\n\n            # q: [s, b, n, 128], q_pos_emb: [s, b, n, 64]\n            q_no_pe, q_pos_emb = torch.split(q, [self.config.qk_head_dim, self.config.qk_pos_emb_head_dim], dim=-1)\n\n            # k_no_pe: [s, b, n, 128], value: [s, b, n, 128]\n            k_no_pe, value = torch.split(kv, [self.config.qk_head_dim, self.config.v_head_dim], dim=-1)\n\n            if packed_seq_params is not None:\n                cu_seqlens_q = packed_seq_params.cu_seqlens_q\n                cu_seqlens_kv = packed_seq_params.cu_seqlens_kv\n                q_pos_emb = q_pos_emb.squeeze(1)\n                k_pos_emb = k_pos_emb.squeeze(1)\n                q_no_pe = q_no_pe.squeeze(1)\n                k_no_pe = k_no_pe.squeeze(1)\n                value = value.squeeze(1)\n            else:\n                cu_seqlens_q = cu_seqlens_kv = None\n\n            # q_pos_emb: [s, b, n, 64], k_pos_emb:[s, b, 1, 64]\n            q_pos_emb = apply_rotary_pos_emb(\n                q_pos_emb,\n                rotary_pos_emb,\n                config=self.config,\n                cu_seqlens=cu_seqlens_q,\n                mscale=mscale,\n            )\n            k_pos_emb = apply_rotary_pos_emb(\n                k_pos_emb,\n                rotary_pos_emb,\n                config=self.config,\n                cu_seqlens=cu_seqlens_kv,\n                mscale=mscale,\n            )\n\n            # query: [s, b, n, 192]\n            query = torch.cat([q_no_pe, q_pos_emb], dim=-1)\n            if packed_seq_params is not None:\n                k_pos_emb = k_pos_emb.expand(-1, self.num_attention_heads_per_partition, -1)\n                key = torch.cat([k_no_pe, k_pos_emb], dim=-1)\n            else:\n                # key: [s, b, n, 192]\n                k_pos_emb = k_pos_emb.expand(-1, -1, self.num_attention_heads_per_partition, -1)\n                key = torch.cat([k_no_pe, k_pos_emb], dim=-1)\n\n            query = query.contiguous()\n            key = key.contiguous()\n            value = value.contiguous()\n            return query, key, value\n\n        if self.recompute_up_proj:\n            self.qkv_up_checkpoint = tensor_parallel.CheckpointWithoutOutput()\n            query, key, value = self.qkv_up_checkpoint.checkpoint(qkv_up_proj_and_rope_apply, q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb)\n        else:\n            query, key, value = qkv_up_proj_and_rope_apply(q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb)\n\n        return query, key, value\n\n    MLASelfAttention.get_query_key_value_tensors = patch_get_query_key_value_tensors\n"
  },
  {
    "path": "siirl/models/mcore/registry.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nRegistry module for model architecture components.\n\"\"\"\n\nfrom enum import Enum\nfrom typing import Callable, Dict, Type\n\nimport torch\nimport torch.nn as nn\n\nfrom .config_converter import (\n    PretrainedConfig,\n    TransformerConfig,\n    hf_to_mcore_config_dense,\n    hf_to_mcore_config_dpskv3,\n    hf_to_mcore_config_llama4,\n    hf_to_mcore_config_mixtral,\n    hf_to_mcore_config_qwen2_5_vl,\n    hf_to_mcore_config_qwen2moe,\n    hf_to_mcore_config_qwen3moe,\n)\nfrom .model_forward import (\n    gptmodel_forward,\n)\nfrom .model_forward_fused import (\n    fused_forward_gptmodel,\n)\nfrom .model_initializer import (\n    BaseModelInitializer,\n    DeepseekV3Model,\n    DenseModel,\n    MixtralModel,\n    Qwen2MoEModel,\n    Qwen3MoEModel,\n    Qwen25VLModel,\n)\nfrom .weight_converter import (\n    McoreToHFWeightConverterDense,\n    McoreToHFWeightConverterDpskv3,\n    McoreToHFWeightConverterMixtral,\n    McoreToHFWeightConverterQwen2Moe,\n    McoreToHFWeightConverterQwen3Moe,\n)\n\n\nclass SupportedModel(Enum):\n    LLAMA = \"LlamaForCausalLM\"  # tested\n    QWEN2 = \"Qwen2ForCausalLM\"  # tested\n    QWEN2_MOE = \"Qwen2MoeForCausalLM\"  # pending\n    DEEPSEEK_V3 = \"DeepseekV3ForCausalLM\"  # not tested\n    MIXTRAL = \"MixtralForCausalLM\"  # tested\n    QWEN2_5_VL = \"Qwen2_5_VLForConditionalGeneration\"  # not supported\n    LLAMA4 = \"Llama4ForConditionalGeneration\"  # not tested\n    QWEN3 = \"Qwen3ForCausalLM\"  # tested\n    QWEN3_MOE = \"Qwen3MoeForCausalLM\"  # not tested\n\n\n# Registry for model configuration converters\nMODEL_CONFIG_CONVERTER_REGISTRY: Dict[SupportedModel, Callable[[PretrainedConfig, torch.dtype], TransformerConfig]] = {\n    SupportedModel.LLAMA: hf_to_mcore_config_dense,\n    SupportedModel.QWEN2: hf_to_mcore_config_dense,\n    SupportedModel.QWEN2_MOE: hf_to_mcore_config_qwen2moe,\n    SupportedModel.DEEPSEEK_V3: hf_to_mcore_config_dpskv3,\n    SupportedModel.MIXTRAL: hf_to_mcore_config_mixtral,\n    SupportedModel.QWEN2_5_VL: hf_to_mcore_config_qwen2_5_vl,\n    SupportedModel.LLAMA4: hf_to_mcore_config_llama4,\n    SupportedModel.QWEN3: hf_to_mcore_config_dense,\n    SupportedModel.QWEN3_MOE: hf_to_mcore_config_qwen3moe,\n}\n\n# Registry for model initializers\nMODEL_INITIALIZER_REGISTRY: Dict[SupportedModel, Type[BaseModelInitializer]] = {\n    SupportedModel.LLAMA: DenseModel,\n    SupportedModel.QWEN2: DenseModel,\n    SupportedModel.QWEN2_MOE: Qwen2MoEModel,\n    SupportedModel.MIXTRAL: MixtralModel,\n    SupportedModel.DEEPSEEK_V3: DeepseekV3Model,\n    SupportedModel.QWEN2_5_VL: Qwen25VLModel,\n    SupportedModel.LLAMA4: DenseModel,\n    SupportedModel.QWEN3: DenseModel,\n    SupportedModel.QWEN3_MOE: Qwen3MoEModel,\n}\n\n# Registry for model forward functions\nMODEL_FORWARD_REGISTRY: Dict[SupportedModel, Callable] = {\n    
SupportedModel.LLAMA: gptmodel_forward,\n    SupportedModel.QWEN2: gptmodel_forward,\n    SupportedModel.QWEN2_MOE: gptmodel_forward,\n    SupportedModel.MIXTRAL: gptmodel_forward,\n    SupportedModel.DEEPSEEK_V3: gptmodel_forward,\n    SupportedModel.QWEN2_5_VL: gptmodel_forward,\n    SupportedModel.LLAMA4: gptmodel_forward,\n    SupportedModel.QWEN3: gptmodel_forward,\n    SupportedModel.QWEN3_MOE: gptmodel_forward,\n}\n\n# Registry for fused model forward functions\nMODEL_FORWARD_FUSED_REGISTRY: Dict[SupportedModel, Callable] = {\n    SupportedModel.LLAMA: fused_forward_gptmodel,\n    SupportedModel.QWEN2: fused_forward_gptmodel,\n    SupportedModel.QWEN2_MOE: fused_forward_gptmodel,\n    SupportedModel.MIXTRAL: fused_forward_gptmodel,\n    SupportedModel.DEEPSEEK_V3: fused_forward_gptmodel,\n    SupportedModel.LLAMA4: fused_forward_gptmodel,\n    SupportedModel.QWEN3: fused_forward_gptmodel,\n    SupportedModel.QWEN3_MOE: fused_forward_gptmodel,\n}\n\n# Registry for model weight converters\nMODEL_WEIGHT_CONVERTER_REGISTRY: Dict[SupportedModel, Type] = {\n    SupportedModel.LLAMA: McoreToHFWeightConverterDense,\n    SupportedModel.QWEN2: McoreToHFWeightConverterDense,\n    SupportedModel.QWEN2_MOE: McoreToHFWeightConverterQwen2Moe,\n    SupportedModel.MIXTRAL: McoreToHFWeightConverterMixtral,\n    SupportedModel.DEEPSEEK_V3: McoreToHFWeightConverterDpskv3,\n    SupportedModel.QWEN3: McoreToHFWeightConverterDense,\n    SupportedModel.QWEN3_MOE: McoreToHFWeightConverterQwen3Moe,\n}\n\n\ndef get_supported_model(model_type: str) -> SupportedModel:\n    try:\n        return SupportedModel(model_type)\n    except ValueError as err:\n        supported_models = [e.value for e in SupportedModel]\n        raise NotImplementedError(f\"Model Type: {model_type} not supported. Supported models: {supported_models}\") from err\n\n\ndef hf_to_mcore_config(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig:\n    assert len(hf_config.architectures) == 1, \"Only one architecture is supported for now\"\n    model = get_supported_model(hf_config.architectures[0])\n    return MODEL_CONFIG_CONVERTER_REGISTRY[model](hf_config, dtype, **override_transformer_config_kwargs)\n\n\ndef init_mcore_model(\n    tfconfig: TransformerConfig,\n    hf_config: PretrainedConfig,\n    pre_process: bool = True,\n    post_process: bool = None,\n    *,\n    share_embeddings_and_output_weights: bool = False,\n    value: bool = False,\n    **extra_kwargs,  # may be used for vlm and moe\n) -> nn.Module:\n    \"\"\"\n    Initialize a Mcore model.\n\n    Args:\n        tfconfig: The transformer config.\n        hf_config: The HuggingFace config.\n        pre_process: Whether this pipeline stage includes the embedding layer.\n        post_process: Whether this pipeline stage includes the output layer.\n        share_embeddings_and_output_weights: Whether to share embeddings and output weights.\n        value: Whether to add an extra linear layer for value prediction (classification or regression).\n        **extra_kwargs: Additional keyword arguments.\n\n    Returns:\n        The initialized model.\n    \"\"\"\n    assert len(hf_config.architectures) == 1, \"Only one architecture is supported for now\"\n    model = get_supported_model(hf_config.architectures[0])\n    initializer_cls = MODEL_INITIALIZER_REGISTRY[model]\n    initializer = initializer_cls(tfconfig, hf_config)\n    return initializer.initialize(pre_process=pre_process, post_process=post_process, share_embeddings_and_output_weights=share_embeddings_and_output_weights, value=value, **extra_kwargs)\n\n\ndef get_mcore_forward_fn(hf_config: PretrainedConfig) -> Callable:\n    \"\"\"\n    Get the forward function for the given model architecture.\n    \"\"\"\n    assert len(hf_config.architectures) == 1, \"Only one architecture is supported for now\"\n    model = get_supported_model(hf_config.architectures[0])\n    return MODEL_FORWARD_REGISTRY[model]\n\n\ndef get_mcore_forward_fused_fn(hf_config: PretrainedConfig) -> Callable:\n    \"\"\"\n    Get the fused forward function for the given model architecture.\n    \"\"\"\n    assert len(hf_config.architectures) == 1, \"Only one architecture is supported for now\"\n    model = get_supported_model(hf_config.architectures[0])\n    return MODEL_FORWARD_FUSED_REGISTRY[model]\n\n\ndef get_mcore_weight_converter(hf_config: PretrainedConfig, dtype: torch.dtype) -> Callable:\n    \"\"\"\n    Get the weight converter for the given model architecture.\n    \"\"\"\n    assert len(hf_config.architectures) == 1, \"Only one architecture is supported for now\"\n    model = get_supported_model(hf_config.architectures[0])\n    tfconfig = hf_to_mcore_config(hf_config, dtype)\n    return MODEL_WEIGHT_CONVERTER_REGISTRY[model](hf_config, tfconfig)\n"
  },
  {
    "path": "siirl/models/mcore/saver.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport time\n\nimport torch\nimport torch.distributed as dist\nfrom megatron.core import mpu\nfrom megatron.core.distributed import DistributedDataParallel as LocalDDP\nfrom megatron.core.transformer.module import Float16Module\nfrom torch.nn.parallel import DistributedDataParallel as torchDDP\n\nfrom siirl.utils.extras.device import get_device_id, get_torch_device\nfrom siirl.utils.megatron.megatron_utils import print_rank_0, unwrap_model\n\n\ndef _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0, cp_rank: int = 0, ep_rank: int = 0):\n    \"\"\"Calculate global rank with support for CP/EP parallelism\"\"\"\n\n    # Get parallel sizes for each dimension\n    tp_size = mpu.get_tensor_model_parallel_world_size()\n    dp_size = mpu.get_data_parallel_world_size()\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    cp_size = mpu.get_context_parallel_world_size()\n    # ep_size = mpu.get_expert_model_parallel_world_size()\n\n    # Verify total GPU count matches (must be consistent with parallel_state.py)\n    total_size = tp_size * dp_size * pp_size * cp_size\n    assert total_size == torch.distributed.get_world_size(), f\"{tp_size}x{dp_size}x{pp_size}x{cp_size} != {torch.distributed.get_world_size()}\"\n\n    # Core calculation logic (corresponds to RankGenerator order parameter)\n    # Assumes default order is \"tp-cp-ep-dp-pp\"\n    return ((pp_rank * dp_size + dp_rank) * cp_size + cp_rank) * tp_size + tp_rank\n\n\ndef _megatron_calc_layer_map(config):\n    \"\"\"Calculate the mapping of global layer_idx to local layer_idx\n    Returns:\n        layer_map (Dict: int -> tuple(int, int, int)):\n            mapping from the global layer index to\n            a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)\n    \"\"\"\n    from megatron.core import mpu\n\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n\n    layer_map = dict()\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers\n\n    for pp_rank_idx in range(pp_size):\n        for virtual_pp_rank_idx in range(virtual_pp_size):\n            layer_offset = virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model\n            for layer_idx in range(num_layers_per_model):\n                layer_map[layer_offset + layer_idx] = (\n                    pp_rank_idx,\n                    virtual_pp_rank_idx,\n                    layer_idx,\n                )\n    return layer_map\n\n\ndef merge_megatron_ckpt_gptmodel(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False):\n    \"\"\"Merge sharded parameters of a Megatron module into 
a merged checkpoint.\n\n    Args:\n        wrapped_models (list of megatron.core.distributed.DistributedDataParallel):\n            The local DDP wrapped megatron modules.\n        config (str or None):\n            HF config for model\n        dtype: model params type\n        is_value_model: if model is value model\n        tie_word_embeddings: tie_word_embeddings\n    Returns:\n        state_dict (dict):\n            The merged state_dict in rank 0, and an empty dictionary in other ranks.\n    \"\"\"\n    start_time = time.time()\n\n    def _get_gpt_model(model):\n        return model\n\n    dp_rank = mpu.get_data_parallel_rank()\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    pp_rank = mpu.get_pipeline_model_parallel_rank()\n    cp_rank = mpu.get_context_parallel_rank()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n    mp_group = mpu.get_model_parallel_group()\n\n    if dist.get_rank() == 0:\n        assert mp_group.rank() == 0, f\"mp_rank:[{mp_group.rank}] != 0 on rank #0\"\n        assert pp_rank == 0, f\"pp_rank:[{pp_rank}] != 0 on rank #0\"\n        assert dp_rank == 0, f\"dp_rank:[{dp_rank}] != 0 on rank #0\"\n\n    if not isinstance(wrapped_models, (list, tuple)):\n        wrapped_models = list(wrapped_models)\n\n    assert len(wrapped_models) == virtual_pp_size\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers\n\n    models = [None] * len(wrapped_models)\n\n    for i, wrapped_model in enumerate(wrapped_models):\n        models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))\n        assert len(models[i].decoder.layers) == num_layers_per_model, \"len model layers {} not equal to num_layers_per_model {}\".format(len(models[i].decoder.layers), num_layers_per_model)\n\n    state_dict = dict()\n\n    def _get_cpu_tensor(tensor: torch.Tensor):\n        if tensor is None:\n            return None\n        if tensor.device == torch.device(\"cpu\"):\n            return tensor.detach().clone()\n        return tensor.detach().cpu()\n\n    def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor:\n        \"\"\"broadcast tensor across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)\n\n        if torch.distributed.get_rank() == src_rank:\n            if tensor is None:\n                weight = None\n                tensor_shape = None\n            else:\n                weight = tensor\n                tensor_shape = weight.shape\n        else:\n            weight = None\n            tensor_shape = None\n\n        obj_list = [tensor_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        tensor_shape = obj_list[0]\n\n        if tensor_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tensor:[{name}] not exist, skip collect\")\n            return\n\n        if weight is None:\n            weight = torch.empty(\n                tensor_shape,\n                dtype=dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n\n        dist.broadcast(weight, src=src_rank, group=mp_group)\n\n        if torch.distributed.get_rank() == 0:\n            state_dict[name] = _get_cpu_tensor(weight)\n\n    def 
_broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        # tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)\n\n        chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{name}] not exist, skip collecting\")\n            return\n\n        buffer_tensor = torch.empty(\n            chunk_shape,\n            dtype=dtype,\n            device=get_device_id(),\n            requires_grad=False,\n        )\n\n        chunk_tensors = [None] * tp_size\n\n        for i in range(tp_size):\n            cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)\n            sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor\n            dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)\n\n            if torch.distributed.get_rank() == 0:\n                chunk_tensors[i] = _get_cpu_tensor(sync_tensor)\n\n        if torch.distributed.get_rank() == 0:\n            full_tensor = torch.concat(chunk_tensors, dim=concat_dim)\n            if mutate_func is not None:\n                full_tensor = mutate_func(full_tensor)\n            state_dict[name] = full_tensor\n\n    def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        # tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)\n\n        chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting\")\n            return\n\n        buffer_tensor = torch.empty(\n            chunk_shape,\n            dtype=dtype,\n            device=get_device_id(),\n            requires_grad=False,\n        )\n\n        chunk_tensors = [None] * tp_size\n\n        for i in range(tp_size):\n            cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)\n            sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor\n            dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)\n\n            if torch.distributed.get_rank() == 0:\n                chunk_tensors[i] = _get_cpu_tensor(sync_tensor)\n\n        if torch.distributed.get_rank() == 0:\n            full_tensor = torch.concat(chunk_tensors, dim=0)\n            intermediate_size_tp = 
config.intermediate_size // tp_size\n            gate_weight_list = []\n            up_weight_list = []\n            for i in range(tp_size):\n                gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)]\n                gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp]\n                up_weight_tp = gate_up_weight_tp[intermediate_size_tp:]\n                gate_weight_list.append(gate_weight_tp)\n                up_weight_list.append(up_weight_tp)\n\n            state_dict[gate_name] = torch.cat(gate_weight_list, dim=0)\n            state_dict[up_name] = torch.cat(up_weight_list, dim=0)\n\n    def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank):\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        # tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)\n\n        chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{q_name}] not exist, skip collecting\")\n            return\n\n        buffer_tensor = torch.empty(\n            chunk_shape,\n            dtype=dtype,\n            device=get_device_id(),\n            requires_grad=False,\n        )\n\n        chunk_tensors = [None] * tp_size\n\n        for i in range(tp_size):\n            cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)\n            sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor\n            dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)\n\n            if torch.distributed.get_rank() == 0:\n                chunk_tensors[i] = _get_cpu_tensor(sync_tensor)\n\n        if torch.distributed.get_rank() == 0:\n            full_tensor = torch.concat(chunk_tensors, dim=0)\n            q_weight_list = []\n            k_weight_list = []\n            v_weight_list = []\n            hidden_size_per_head = getattr(config, \"head_dim\", config.hidden_size // config.num_attention_heads)\n\n            if config.num_key_value_heads >= tp_size:\n                q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size\n                kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size\n                total_size = q_size_tp + 2 * kv_size_tp\n                for i in range(tp_size):\n                    num_query_groups_per_partition = wrapped_models[0].config.num_query_groups // tp_size\n                    qkv_part = full_tensor[i * total_size : (i + 1) * total_size]\n                    q_size_chunk = q_size_tp // num_query_groups_per_partition\n                    kv_size_chunk = kv_size_tp // num_query_groups_per_partition\n                    for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition):\n                        q_part = qkv_part_chunk[:q_size_chunk]\n                        k_part = qkv_part_chunk[q_size_chunk : q_size_chunk + kv_size_chunk]\n                        v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk :]\n        
                q_weight_list.append(q_part)\n                        k_weight_list.append(k_part)\n                        v_weight_list.append(v_part)\n            else:\n                q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size\n                kv_size_tp = hidden_size_per_head\n                total_size = q_size_tp + 2 * kv_size_tp\n                for i in range(tp_size):\n                    num_query_groups_per_partition = wrapped_models[0].config.num_query_groups // tp_size\n                    qkv_part = full_tensor[i * total_size : (i + 1) * total_size]\n                    q_size_chunk = q_size_tp // num_query_groups_per_partition\n                    kv_size_chunk = kv_size_tp // num_query_groups_per_partition\n                    for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition):\n                        q_part = qkv_part_chunk[:q_size_chunk]\n                        k_part = qkv_part_chunk[q_size_chunk : q_size_chunk + kv_size_chunk]\n                        v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk :]\n                        q_weight_list.append(q_part)\n                        if i * config.num_key_value_heads % tp_size == 0:\n                            k_weight_list.append(k_part)\n                            v_weight_list.append(v_part)\n\n            state_dict[q_name] = torch.cat(q_weight_list, dim=0)\n            state_dict[k_name] = torch.cat(k_weight_list, dim=0)\n            state_dict[v_name] = torch.cat(v_weight_list, dim=0)\n\n    # empty cache before collecting weights\n    get_torch_device().empty_cache()\n    # Embeddings\n    # -------------------\n    if dp_rank == 0 and cp_rank == 0:  # models are identical across cp ranks\n        # Embeddings\n        # -------------------\n        print_rank_0(\"collecting embeddings...\")\n        gpt_model_module = _get_gpt_model(models[0])\n        _broadcast_tp_shard_tensor(\n            gpt_model_module.embedding.word_embeddings.weight if pp_rank == 0 else None,\n            \"model.embed_tokens.weight\",\n            src_pp_rank=0,\n        )\n\n        # Transformer layers\n        # -------------------\n        layer_map = _megatron_calc_layer_map(config)\n        for layer in range(config.num_hidden_layers):\n            print_rank_0(f\"collecting layer #{layer}...\")\n            layer_name = f\"model.layers.{layer}\"\n            src_pp_rank, src_virtual_pp_rank, src_layer_idx = layer_map[layer]\n\n            gpt_model_module = _get_gpt_model(models[src_virtual_pp_rank])\n            sync_layer = gpt_model_module.decoder.layers[src_layer_idx]\n\n            _broadcast_tensor(\n                sync_layer.self_attention.linear_qkv.layer_norm_weight,\n                f\"{layer_name}.input_layernorm.weight\",\n                src_pp_rank=src_pp_rank,\n            )\n\n            if gpt_model_module.config.qk_layernorm:\n                _broadcast_tensor(\n                    sync_layer.self_attention.q_layernorm.weight,\n                    f\"{layer_name}.self_attn.q_norm.weight\",\n                    src_pp_rank=src_pp_rank,\n                )\n                _broadcast_tensor(\n                    sync_layer.self_attention.k_layernorm.weight,\n                    f\"{layer_name}.self_attn.k_norm.weight\",\n                    src_pp_rank=src_pp_rank,\n                )\n\n            _broadcast_tp_shard_tensor_qkv(\n                sync_layer.self_attention.linear_qkv.weight,\n                f\"{layer_name}.self_attn.q_proj.weight\",\n         
       f\"{layer_name}.self_attn.k_proj.weight\",\n                f\"{layer_name}.self_attn.v_proj.weight\",\n                src_pp_rank=src_pp_rank,\n            )\n\n            if gpt_model_module.config.add_qkv_bias:\n                _broadcast_tp_shard_tensor_qkv(\n                    sync_layer.self_attention.linear_qkv.bias,\n                    f\"{layer_name}.self_attn.q_proj.bias\",\n                    f\"{layer_name}.self_attn.k_proj.bias\",\n                    f\"{layer_name}.self_attn.v_proj.bias\",\n                    src_pp_rank=src_pp_rank,\n                )\n\n            _broadcast_tp_shard_tensor(\n                sync_layer.self_attention.linear_proj.weight,\n                f\"{layer_name}.self_attn.o_proj.weight\",\n                concat_dim=1,\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tensor(\n                sync_layer.mlp.linear_fc1.layer_norm_weight,\n                f\"{layer_name}.post_attention_layernorm.weight\",\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tp_shard_tensor_gate_up(\n                sync_layer.mlp.linear_fc1.weight,\n                f\"{layer_name}.mlp.gate_proj.weight\",\n                f\"{layer_name}.mlp.up_proj.weight\",\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tp_shard_tensor(\n                sync_layer.mlp.linear_fc2.weight,\n                f\"{layer_name}.mlp.down_proj.weight\",\n                concat_dim=1,\n                src_pp_rank=src_pp_rank,\n            )\n\n        # Final Layernorm\n        # -------------------\n        print_rank_0(\"collecting final layernorm...\")\n        gpt_model_module = _get_gpt_model(models[-1])\n        _broadcast_tensor(\n            getattr(gpt_model_module.decoder.final_layernorm, \"weight\", None),\n            \"model.norm.weight\",\n            src_pp_rank=pp_size - 1,\n        )\n\n        if tie_word_embeddings:\n            print_rank_0(\"tie word embedding skip load lm_head...\")\n        else:\n            print_rank_0(\"collecting lm_head...\")\n\n            if is_value_model:\n                lm_head_weight = None\n                if pp_rank == pp_size - 1:\n                    lm_head_weight = getattr(gpt_model_module.output_layer, \"weight\", None)\n                _broadcast_tensor(lm_head_weight, \"lm_head.weight\", src_pp_rank=pp_size - 1)\n\n            else:\n                _broadcast_tp_shard_tensor(\n                    getattr(gpt_model_module.output_layer, \"weight\", None) if pp_rank == pp_size - 1 else None,\n                    \"lm_head.weight\",\n                    src_pp_rank=pp_size - 1,\n                )\n\n    dist.barrier()\n    get_torch_device().empty_cache()\n    if torch.distributed.get_rank() == 0:\n        for k, v in state_dict.items():\n            if dtype != v.dtype:\n                state_dict[k] = v.to(dtype)\n\n    print_rank_0(f\"merge megatron ckpt done, time elapsed {time.time() - start_time}s\")\n    return state_dict\n\n\ndef merge_megatron_ckpt_gptmodel_qwen_moe(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False):\n    raise NotImplementedError(\"merge_megatron_ckpt_gptmodel_qwen_moe is not implemented\")\n\n\ndef merge_megatron_ckpt_gptmodel_dpskv3(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False):\n    raise NotImplementedError(\"merge_megatron_ckpt_gptmodel_dpskv3 is not implemented\")\n\n\ndef merge_megatron_ckpt_gptmodel_mixtral(wrapped_models, config, 
dtype, is_value_model=False, tie_word_embeddings=False):\n    raise NotImplementedError(\"merge_megatron_ckpt_gptmodel_mixtral is not implemented\")\n"
  },
  {
    "path": "siirl/models/mcore/util.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nfrom megatron.core import parallel_state as mpu\nfrom megatron.core.packed_seq_params import PackedSeqParams\n\nfrom siirl.utils.model_utils.model import CausalLMOutputForPPO\n\n\ndef preprocess_packed_seqs(\n    input_ids: torch.Tensor, attention_mask: torch.Tensor, pre_process: bool = True\n) -> tuple[torch.Tensor, PackedSeqParams]:\n    \"\"\"\n    Preprocess packed sequences\n    CP splits sequence into CP*2 chunks, and each GPU gets 2 chunks (GPU0 gets first and last chunks, GPU1\n    gets second and second last chunks, and so on), this is for load balancing with causal masking.\n    See https://github.com/NVIDIA/TransformerEngine/issues/1368\n    \"\"\"\n    batch_size = input_ids.shape[0]\n\n    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)\n    tp_size = mpu.get_tensor_model_parallel_world_size()\n    cp_size = mpu.get_context_parallel_world_size()\n    cp_rank = mpu.get_context_parallel_rank()\n    align_size = tp_size * cp_size * 2 if cp_size > 1 else tp_size\n\n    pad_size = (align_size - seqlens_in_batch % align_size) % align_size\n    seqlens_in_batch_padded = seqlens_in_batch + pad_size\n\n    cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device)\n    cu_seqlens[1:] = torch.cumsum(seqlens_in_batch, dim=0)\n    cu_seqlens_padded = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device)\n    cu_seqlens_padded[1:] = torch.cumsum(seqlens_in_batch_padded, dim=0)\n\n    # ----------------------------------------------------------------------------\n    # Move the index information needed in the subsequent loop to the CPU at once,\n    # to avoid frequent .item() calls in the loop that cause D2H synchronization\n    # ----------------------------------------------------------------------------\n    seqlens_in_batch_cpu: list[int] = seqlens_in_batch.tolist()  # original valid lengths\n    seqlens_in_batch_padded_cpu: list[int] = seqlens_in_batch_padded.tolist()  # lengths after padding\n    cu_seqlens_padded_cpu: list[int] = cu_seqlens_padded.tolist()  # start positions (after padding)\n\n    # Pure Python int calculation to avoid further synchronization\n    max_seqlen_in_batch = max(seqlens_in_batch_padded_cpu)\n\n    shape = list(input_ids.shape[1:])\n    shape[0] = sum(seqlens_in_batch_padded_cpu) // cp_size\n    if pre_process:\n        input_ids_rmpad = torch.zeros(shape, dtype=input_ids.dtype, device=input_ids.device)\n        for i in range(batch_size):\n            # Use Python int, so no GPU→CPU sync in the loop\n            if cp_size <= 1:\n                seqlen = seqlens_in_batch_cpu[i]\n                start_idx = cu_seqlens_padded_cpu[i]\n                input_ids_rmpad[start_idx : start_idx + seqlen] = input_ids[i, attention_mask[i]]\n                continue\n\n            seqlen_padded_i = 
seqlens_in_batch_padded_cpu[i]\n            seqlen = seqlen_padded_i // cp_size\n            half_seqlen = seqlen // 2\n            start_idx = cu_seqlens_padded_cpu[i] // cp_size\n            # split to 2 chunks\n            d = input_ids[i, attention_mask[i]]\n            input_ids_rmpad[start_idx : start_idx + half_seqlen] = d[\n                half_seqlen * cp_rank : half_seqlen * (cp_rank + 1)\n            ]\n\n            remain_start = seqlen_padded_i - half_seqlen * (cp_rank + 1)\n            remain_end = seqlen_padded_i - half_seqlen * cp_rank\n            remain_end = min(remain_end, d.shape[0])\n            remain_len = remain_end - remain_start\n            if remain_len > 0:\n                input_ids_rmpad[start_idx + half_seqlen : start_idx + half_seqlen + remain_len] = d[\n                    remain_start:remain_end\n                ]\n\n    packed_seq_params = PackedSeqParams(\n        qkv_format=\"thd\",\n        cu_seqlens_q=cu_seqlens_padded,\n        max_seqlen_q=max_seqlen_in_batch,\n        cu_seqlens_kv=cu_seqlens_padded,\n        max_seqlen_kv=max_seqlen_in_batch,\n        cu_seqlens_q_padded=cu_seqlens_padded,\n        cu_seqlens_kv_padded=cu_seqlens_padded,\n    )\n    if pre_process:\n        return input_ids_rmpad.unsqueeze(0), packed_seq_params\n    else:\n        return input_ids, packed_seq_params\n\n\ndef postprocess_packed_seqs(\n    output: torch.Tensor,\n    packed_seq_params: PackedSeqParams,\n    attention_mask: torch.Tensor,\n    batch_size: int,\n    seq_len: int,\n    post_process: bool = True,\n) -> torch.Tensor:\n    \"\"\"\n    Postprocess packed sequences\n    \"\"\"\n    if not post_process:\n        return output\n\n    # -------------------------------------------------------------------------\n    # Move the lengths and offsets needed for subsequent Python-level indexing to the CPU in advance,\n    # to avoid a large number of .item() calls in the loop\n    # -------------------------------------------------------------------------\n    cu_padded_cpu: list[int] = packed_seq_params.cu_seqlens_q_padded.tolist()\n    seq_lens_cpu: list[int] = attention_mask.sum(dim=1, dtype=torch.int32).cpu().tolist()\n\n    shape = [batch_size, seq_len] + list(output.shape[2:])  # 1,packed, dim -> batch_size, seq_len, dim\n    output_new = torch.zeros(shape, dtype=output.dtype, device=output.device)\n\n    cp_size = mpu.get_context_parallel_world_size()\n    # all gather output across context parallel group\n    if cp_size > 1:\n        # output shape: [1, packed_len, hidden_dim]\n        # need to gather across cp group and concatenate in sequence dimension\n        output_list = [torch.empty_like(output) for _ in range(cp_size)]\n        torch.distributed.all_gather(output_list, output.detach(), group=mpu.get_context_parallel_group())\n        output_list[mpu.get_context_parallel_rank()] = output\n    else:\n        output_list = [output]\n    for i in range(batch_size):\n        if cp_size <= 1:\n            s = seq_lens_cpu[i]\n            start_idx = cu_padded_cpu[i]\n            output_new[i, attention_mask[i]] = output[0][start_idx : start_idx + s]\n            continue\n        s_len_padded_chunk = (cu_padded_cpu[i + 1] - cu_padded_cpu[i]) // cp_size\n        half_seqlen = s_len_padded_chunk // 2\n        s_len = seq_lens_cpu[i]\n        s_len_padded = s_len_padded_chunk * cp_size\n        tmp = torch.empty(s_len_padded, *output.shape[2:], device=output.device)\n        for j in range(cp_size):\n            o = output_list[j][0]\n            # split to 
2 chunks\n            packed_start_idx = cu_padded_cpu[i] // cp_size\n            o0, o1 = (\n                o[packed_start_idx : packed_start_idx + half_seqlen],\n                o[packed_start_idx + half_seqlen : packed_start_idx + s_len_padded_chunk],\n            )\n            tmp[j * half_seqlen : (j + 1) * half_seqlen] = o0\n            tmp[s_len_padded - (j + 1) * half_seqlen : s_len_padded - j * half_seqlen] = o1\n        output_new[i, attention_mask[i]] = tmp[:s_len]\n\n    return output_new\n\n\ndef remove_left_padding(\n    input_ids: torch.Tensor,\n    attention_mask: torch.Tensor,\n    position_ids: torch.Tensor,\n    sequence_parallel: bool = False,\n    pre_process: bool = True,\n):\n    \"\"\"\n    Remove left padding from input_ids, attention_mask and position_ids\n    return new_input_ids, new_attention_mask, new_position_ids\n    \"\"\"\n    assert attention_mask.ndim == 2\n    assert position_ids.ndim == 2\n    cp_size = mpu.get_context_parallel_world_size()\n    assert cp_size == 1, \"Context parallel size without seq_pack is not supported\"\n    batch_size = input_ids.shape[0]\n    shape = list(input_ids.shape)  # batch_size, seq_len,...\n    seq_lens = attention_mask.sum(dim=1)\n    seq_len = seq_lens.max().item()\n    if sequence_parallel:\n        sp_world_size = mpu.get_tensor_model_parallel_world_size()\n        pad_size = (sp_world_size - seq_len % sp_world_size) % sp_world_size\n        seq_len = seq_len + pad_size\n    shape[1] = seq_len\n    if pre_process:\n        new_input_ids = torch.zeros(dtype=input_ids.dtype, device=input_ids.device, size=shape)\n    new_attention_mask = torch.zeros(\n        dtype=attention_mask.dtype, device=attention_mask.device, size=(batch_size, seq_len)\n    )\n    new_position_ids = torch.zeros(dtype=position_ids.dtype, device=position_ids.device, size=(batch_size, seq_len))\n    for i in range(batch_size):\n        if pre_process:\n            new_input_ids[i, : seq_lens[i]] = input_ids[i, attention_mask[i]]\n        new_attention_mask[i, : seq_lens[i]] = attention_mask[i, attention_mask[i]]\n        new_position_ids[i, : seq_lens[i]] = position_ids[i, attention_mask[i]]\n    if pre_process:\n        return new_input_ids, new_attention_mask, new_position_ids\n    else:\n        return input_ids, new_attention_mask, new_position_ids\n\n\ndef recover_left_padding(\n    result,\n    attention_mask: torch.Tensor,\n    original_attention_mask: torch.Tensor,\n    origin_seqlen: int,\n    post_process: bool = True,\n):\n    \"\"\"\n    Recover left padding from result\n    return result\n    \"\"\"\n    if not post_process:\n        return result\n    shape = list(result.shape)\n    batch_size = shape[0]\n    shape[1] = origin_seqlen\n    new_result = torch.zeros(dtype=result.dtype, device=result.device, size=shape)\n    for i in range(batch_size):\n        new_result[i, original_attention_mask[i]] = result[i, attention_mask[i]]\n    return new_result\n\n\ndef postprocess_packed_seqs_for_dict_output(\n    labels_mask: torch.Tensor,\n    output: CausalLMOutputForPPO,\n    packed_seq_params: PackedSeqParams,\n    attention_mask: torch.Tensor,\n    batch_size: int,\n    seq_len: int,\n    post_process: bool = True,\n) -> dict[str, torch.Tensor]:\n    \"\"\"_summary_\n    For fused kernels, the output is a dictionary with keys like 'log_probs', 'entropy', etc.\n    This function post-processes each tensor in the output dictionary.\n    Args:\n        output (CausalLMOutputForPPO): _description_\n        packed_seq_params 
(PackedSeqParams): packing metadata (cu_seqlens, etc.) produced by preprocess_packed_seqs\n        attention_mask (torch.Tensor): original, unpacked attention mask of shape [batch_size, seq_len]\n        batch_size (int): original batch size\n        seq_len (int): original (padded) sequence length of the unpacked batch\n        post_process (bool, optional): whether this pipeline stage performs post-processing. Defaults to True.\n    Returns:\n        dict[str, torch.Tensor]: unpacked 'entropy' and 'log_probs' tensors of shape [batch_size, seq_len]\n    \"\"\"\n    ret = {}\n    output.entropy = output.entropy.view(1, -1)\n    output.log_probs = output.log_probs.view(1, -1)\n    output.log_probs = output.log_probs.masked_fill(~labels_mask, 0.0)\n    ret[\"entropy\"] = postprocess_packed_seqs(\n        output.entropy, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process\n    )\n    ret[\"log_probs\"] = postprocess_packed_seqs(\n        output.log_probs, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process\n    )\n    return ret\n"
  },
  {
    "path": "siirl/models/mcore/weight_converter.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# online convert mcore weight to pure huggingface weight, no any fusion\n# including format conversion and name mapping\n# not including resharding\nimport torch\nfrom megatron.core.transformer import TransformerConfig\nfrom transformers import PretrainedConfig\n\n\nclass McoreToHFWeightConverterBase:\n    def __init__(self, hf_config: PretrainedConfig, mcore_config: TransformerConfig):\n        self.hf_config = hf_config\n        self.mcore_config = mcore_config\n\n    def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> torch.Tensor:\n        raise NotImplementedError\n\n\nclass McoreToHFWeightConverterDense(McoreToHFWeightConverterBase):\n    def _convert_attention_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:\n        # 'decoder.layers.0.self_attention.linear_proj.weight'\n        # 'decoder.layers.0.self_attention.linear_qkv.layer_norm_weight'\n        # 'decoder.layers.0.self_attention.linear_qkv.weight'\n        # 'decoder.layers.0.self_attention.linear_qkv.bias'\n        layer_number = name.split(\".\")[2]\n        convert_names = []\n        if \"self_attention.linear_qkv.bias\" in name or \"self_attention.linear_qkv.weight\" in name:\n            param_type = name.split(\".\")[-1]\n            assert param_type == \"bias\" or param_type == \"weight\"\n            convert_names.append(f\"model.layers.{layer_number}.self_attn.q_proj.{param_type}\")\n            convert_names.append(f\"model.layers.{layer_number}.self_attn.k_proj.{param_type}\")\n            convert_names.append(f\"model.layers.{layer_number}.self_attn.v_proj.{param_type}\")\n            assert len(params) == 3\n        elif \"self_attention.linear_proj.weight\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.self_attn.o_proj.weight\")\n            assert len(params) == 1\n        elif \"self_attention.linear_qkv.layer_norm_weight\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.input_layernorm.weight\")\n            assert len(params) == 1\n        elif \"self_attention.q_layernorm.weight\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.self_attn.q_norm.weight\")\n            assert len(params) == 1\n        elif \"self_attention.k_layernorm.weight\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.self_attn.k_norm.weight\")\n            assert len(params) == 1\n        else:\n            raise NotImplementedError(f\"Unsupported parameter name: {name}\")\n        return convert_names, params\n\n    def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:\n        # 'decoder.layers.0.mlp.linear_fc1.layer_norm_weight'\n        # 
'decoder.layers.0.mlp.linear_fc1.weight'\n        # 'decoder.layers.0.mlp.linear_fc2.weight'\n        layer_number = name.split(\".\")[2]\n        convert_names = []\n        if \"mlp.linear_fc1.weight\" in name:\n            # split gate_proj and up_proj\n            convert_names.append(f\"model.layers.{layer_number}.mlp.gate_proj.weight\")\n            convert_names.append(f\"model.layers.{layer_number}.mlp.up_proj.weight\")\n            assert len(params) == 2\n        elif \"mlp.linear_fc1.layer_norm_weight\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.post_attention_layernorm.weight\")\n            assert len(params) == 1\n        elif \"mlp.linear_fc2.weight\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.mlp.down_proj.weight\")\n            assert len(params) == 1\n        else:\n            raise NotImplementedError(f\"Unsupported parameter name: {name}\")\n        return convert_names, params\n\n    def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:\n        direct_name_mapping = {\n            \"embedding.word_embeddings.weight\": \"model.embed_tokens.weight\",\n            \"decoder.final_layernorm.weight\": \"model.norm.weight\",\n            \"output_layer.weight\": \"lm_head.weight\",\n        }\n        if name in direct_name_mapping:\n            return [direct_name_mapping[name]], [params_one_group[0]]\n\n        if \"self_attention\" in name:\n            return self._convert_attention_param(name, params_one_group)\n        elif \"mlp\" in name:\n            return self._convert_mlp_param(name, params_one_group)\n        else:\n            raise NotImplementedError(f\"Unsupported parameter name: {name}\")\n\n\nclass McoreToHFWeightConverterQwen2Moe(McoreToHFWeightConverterDense):\n    def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:\n        # 'decoder.layers.0.pre_mlp_layernorm.weight',\n        # 'decoder.layers.0.mlp.router.weight',\n        # 'decoder.layers.0.mlp.shared_experts.gate_weight',\n        # 'decoder.layers.0.mlp.shared_experts.linear_fc1.weight',\n        # 'decoder.layers.0.mlp.shared_experts.linear_fc2.weight'\n        # moe1\n        # 'decoder.layers.0.mlp.experts.linear_fc1.weight0',\n        # 'decoder.layers.0.mlp.experts.linear_fc1.weight1',\n        # 'decoder.layers.0.mlp.experts.linear_fc1.weight2',\n        # 'decoder.layers.0.mlp.experts.linear_fc1.weight3',\n        # moe2\n        # 'decoder.layers.0.mlp.experts.linear_fc2.weight0',\n        # 'decoder.layers.0.mlp.experts.linear_fc2.weight1',\n        layer_number = name.split(\".\")[2]\n        convert_names = []\n        if \"pre_mlp_layernorm\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.post_attention_layernorm.weight\")\n            assert len(params) == 1\n        elif \"mlp.router.weight\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.mlp.gate.weight\")\n            assert len(params) == 1\n        elif \"shared_experts.gate_weight\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.mlp.shared_expert_gate.weight\")\n            assert len(params) == 1\n        elif \"shared_experts.linear_fc1.weight\" in name:  # split gate_proj and up_proj\n            convert_names.append(f\"model.layers.{layer_number}.mlp.shared_expert.gate_proj.weight\")\n            
convert_names.append(f\"model.layers.{layer_number}.mlp.shared_expert.up_proj.weight\")\n            assert len(params) == 2\n        elif \"shared_experts.linear_fc2.weight\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.mlp.shared_expert.down_proj.weight\")\n            assert len(params) == 1\n        elif \"mlp.experts.linear_fc1\" in name:  # split gate_proj and up_proj\n            expert_id = name.split(\"weight\")[-1]\n            convert_names.append(f\"model.layers.{layer_number}.mlp.experts.{expert_id}.gate_proj.weight\")\n            convert_names.append(f\"model.layers.{layer_number}.mlp.experts.{expert_id}.up_proj.weight\")\n            assert len(params) == 2\n        elif \"mlp.experts.linear_fc2\" in name:\n            expert_id = name.split(\"weight\")[-1]\n            convert_names.append(f\"model.layers.{layer_number}.mlp.experts.{expert_id}.down_proj.weight\")\n            assert len(params) == 1\n        else:\n            raise NotImplementedError(f\"Unsupported parameter name: {name}\")\n        return convert_names, params\n\n\nclass McoreToHFWeightConverterDpskv3(McoreToHFWeightConverterBase):\n    def _convert_attention_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:\n        # mcore\n        # 'decoder.layers.0.input_layernorm.weight'\n        # 'decoder.layers.0.self_attention.linear_proj.weight'\n        # 'decoder.layers.0.self_attention.linear_q_proj.weight'\n        # 'decoder.layers.0.self_attention.linear_kv_down_proj.weight'\n        # 'decoder.layers.0.self_attention.linear_kv_up_proj.layer_norm_weight'\n        # 'decoder.layers.0.self_attention.linear_kv_up_proj.weight'\n        # 'decoder.layers.0.self_attention.linear_q_down_proj.weight'\n        # 'decoder.layers.0.self_attention.linear_q_up_proj.weight'\n        # 'decoder.layers.0.self_attention.linear_q_up_proj.layer_norm_weight'\n        # hf\n        # 'model.layers.0.input_layernorm.weight'\n        # 'model.layers.0.self_attn.o_proj.weight'\n        # 'model.layers.0.self_attn.q_proj.weight'\n        # 'model.layers.0.self_attn.kv_a_proj_with_mqa.weight'\n        # 'model.layers.0.self_attn.kv_a_layernorm.weight'\n        # 'model.layers.0.self_attn.kv_b_proj.weight'\n        # 'model.layers.0.self_attn.q_a_proj.weight'\n        # 'model.layers.0.self_attn.q_b_proj.weight'\n        # 'model.layers.0.self_attn.q_a_layernorm.weight'\n        name_map_after_layer = {\n            \"input_layernorm.weight\": \"input_layernorm.weight\",\n            \"self_attention.linear_proj.weight\": \"self_attn.o_proj.weight\",\n            \"self_attention.linear_q_proj.weight\": \"self_attn.q_proj.weight\",\n            \"self_attention.linear_kv_down_proj.weight\": \"self_attn.kv_a_proj_with_mqa.weight\",\n            \"self_attention.linear_kv_up_proj.layer_norm_weight\": \"self_attn.kv_a_layernorm.weight\",\n            \"self_attention.linear_kv_up_proj.weight\": \"self_attn.kv_b_proj.weight\",\n            \"self_attention.linear_q_down_proj.weight\": \"self_attn.q_a_proj.weight\",\n            \"self_attention.linear_q_up_proj.weight\": \"self_attn.q_b_proj.weight\",\n            \"self_attention.linear_q_up_proj.layer_norm_weight\": \"self_attn.q_a_layernorm.weight\",\n        }\n        assert len(params) == 1\n        convert_names = []\n        layer_number = name.split(\".\")[2]\n        name_after_layer = name.split(f\".{layer_number}.\")[1]\n        
convert_names.append(f\"model.layers.{layer_number}.{name_map_after_layer[name_after_layer]}\")\n        return convert_names, params\n\n    def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:\n        # mcore dense\n        # 'decoder.layers.0.mlp.linear_fc1.layer_norm_weight'\n        # 'decoder.layers.0.mlp.linear_fc2.weight'\n        # 'decoder.layers.0.mlp.linear_fc1.weight'\n        #       ---\n        # 'decoder.layers.1.mlp.shared_experts.linear_fc1.weight'\n        #       ---\n        # 'decoder.layers.1.mlp.shared_experts.linear_fc2.weight'\n        # hf dense\n        # 'model.layers.0.post_attention_layernorm.weight'\n        # 'model.layers.0.mlp.down_proj.weight'\n        # 'model.layers.0.mlp.gate_proj.weight'\n        # 'model.layers.0.mlp.up_proj.weight'\n        # 'model.layers.1.mlp.shared_experts.gate_proj.weight'\n        # 'model.layers.1.mlp.shared_experts.up_proj.weight'\n        # 'model.layers.1.mlp.shared_experts.down_proj.weight'\n\n        # mcore moe\n        # 'decoder.layers.1.pre_mlp_layernorm.weight'\n        # 'decoder.layers.1.mlp.router.weight'\n        # 'decoder.layers.1.mlp.router.expert_bias'\n        # 'decoder.layers.1.mlp.experts.linear_fc1.weight0'\n        #       ---\n        # 'decoder.layers.1.mlp.experts.linear_fc2.weight0'\n        # hf moe\n        # 'model.layers.1.post_attention_layernorm.weight'\n        # 'model.layers.1.mlp.gate.weight'\n        # 'model.layers.1.mlp.gate.e_score_correction_bias'\n        # 'model.layers.1.mlp.experts.0.gate_proj.weight'\n        # 'model.layers.1.mlp.experts.0.up_proj.weight'\n        # 'model.layers.1.mlp.experts.0.down_proj.weight'\n\n        name_map_after_layer = {\n            \"mlp.linear_fc1.layer_norm_weight\": \"post_attention_layernorm.weight\",\n            \"mlp.linear_fc2.weight\": \"mlp.down_proj.weight\",\n            \"mlp.shared_experts.linear_fc2.weight\": \"mlp.shared_experts.down_proj.weight\",\n            \"mlp.linear_fc1.weight\": [\"mlp.gate_proj.weight\", \"mlp.up_proj.weight\"],\n            \"mlp.shared_experts.linear_fc1.weight\": [\"mlp.shared_experts.gate_proj.weight\", \"mlp.shared_experts.up_proj.weight\"],\n            \"pre_mlp_layernorm.weight\": \"post_attention_layernorm.weight\",\n            \"mlp.router.weight\": \"mlp.gate.weight\",\n            \"mlp.router.expert_bias\": \"mlp.gate.e_score_correction_bias\",\n        }\n        convert_names = []\n        layer_number = name.split(\".\")[2]\n        name_after_layer = name.split(f\".{layer_number}.\")[1]\n        if name_after_layer in name_map_after_layer:\n            mapped_name = name_map_after_layer[name_after_layer]\n            if isinstance(mapped_name, list):\n                assert len(params) == len(mapped_name)\n                for one in mapped_name:\n                    convert_names.append(f\"model.layers.{layer_number}.{one}\")\n            else:\n                assert len(params) == 1\n                convert_names.append(f\"model.layers.{layer_number}.{mapped_name}\")\n        else:\n            if \"mlp.experts.linear_fc1.weight\" in name:\n                expert_id = name.split(\"weight\")[-1]\n                convert_names.append(f\"model.layers.{layer_number}.mlp.experts.{expert_id}.gate_proj.weight\")\n                convert_names.append(f\"model.layers.{layer_number}.mlp.experts.{expert_id}.up_proj.weight\")\n                assert len(params) == 2\n            elif \"mlp.experts.linear_fc2.weight\" in name:\n               
 expert_id = name.split(\"weight\")[-1]\n                convert_names.append(f\"model.layers.{layer_number}.mlp.experts.{expert_id}.down_proj.weight\")\n                assert len(params) == 1\n            else:\n                raise NotImplementedError(f\"Unsupported parameter name: {name}\")\n\n        return convert_names, params\n\n    def _convert_mtp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:\n        assert self.mcore_config.mtp_num_layers == 1, \"only support one mtp layer for now\"\n        assert self.mcore_config.num_layers == 61, \"only support 61 layers for now\"\n        direct_name_mapping = {\"mtp.layers.0.enorm.weight\": \"model.layers.61.enorm.weight\", \"mtp.layers.0.hnorm.weight\": \"model.layers.61.hnorm.weight\", \"mtp.layers.0.eh_proj.weight\": \"model.layers.61.eh_proj.weight\", \"mtp.layers.0.final_layernorm.weight\": \"model.layers.61.shared_head.norm.weight\"}\n        if name in direct_name_mapping:\n            return [direct_name_mapping[name]], [params[0]]\n        assert \"mtp.layers.0.transformer_layer\" in name, \"only support transformer layer for now\"\n        # use proxy name to convert\n        proxy_name = name.replace(\"mtp.layers.0.transformer_layer\", \"decoder.layers.61\")\n        if \"self_attention\" in proxy_name or \"input_layernorm.weight\" in proxy_name:\n            convert_names, params = self._convert_attention_param(proxy_name, params)\n        elif \"mlp\" in proxy_name:\n            convert_names, params = self._convert_mlp_param(proxy_name, params)\n        else:\n            raise NotImplementedError(f\"Unsupported parameter name: {name}\")\n        return convert_names, params\n\n    def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:\n        direct_name_mapping = {\n            \"embedding.word_embeddings.weight\": \"model.embed_tokens.weight\",\n            \"decoder.final_layernorm.weight\": \"model.norm.weight\",\n            \"output_layer.weight\": \"lm_head.weight\",\n        }\n        if name in direct_name_mapping:\n            return [direct_name_mapping[name]], [params_one_group[0]]\n        if \"mtp\" in name:\n            return self._convert_mtp_param(name, params_one_group)\n        elif \"self_attention\" in name or \"input_layernorm.weight\" in name:\n            return self._convert_attention_param(name, params_one_group)\n        elif \"mlp\" in name:\n            return self._convert_mlp_param(name, params_one_group)\n        else:\n            raise NotImplementedError(f\"Unsupported parameter name: {name}\")\n\n\nclass McoreToHFWeightConverterMixtral(McoreToHFWeightConverterDense):\n    def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:\n        # decoder.layers.0.mlp.router.weight\n        # decoder.layers.0.mlp.experts.linear_fc1.weight0 - weight7\n        # decoder.layers.0.mlp.experts.linear_fc2.weight0 - weight7\n\n        layer_number = name.split(\".\")[2]\n        convert_names = []\n        if \"pre_mlp_layernorm\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.post_attention_layernorm.weight\")\n        elif \"mlp.router.weight\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.block_sparse_moe.gate.weight\")\n        elif \"mlp.experts.linear_fc1.weight\" in name:\n            expert_id = name.split(\"weight\")[-1]\n            
convert_names.append(f\"model.layers.{layer_number}.block_sparse_moe.experts.{expert_id}.w1.weight\")\n            convert_names.append(f\"model.layers.{layer_number}.block_sparse_moe.experts.{expert_id}.w3.weight\")\n        elif \"mlp.experts.linear_fc2.weight\" in name:\n            expert_id = name.split(\"weight\")[-1]\n            convert_names.append(f\"model.layers.{layer_number}.block_sparse_moe.experts.{expert_id}.w2.weight\")\n        else:\n            raise NotImplementedError(f\"Unsupported parameter name: {name}\")\n        return convert_names, params\n\n\nclass McoreToHFWeightConverterQwen3Moe(McoreToHFWeightConverterDense):\n    def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:\n        # qwen3 moe no share expert\n\n        # 'decoder.layers.0.pre_mlp_layernorm.weight',\n        # 'decoder.layers.0.mlp.router.weight',\n        # moe1\n        # 'decoder.layers.0.mlp.experts.linear_fc1.weight0',\n        # 'decoder.layers.0.mlp.experts.linear_fc1.weight1',\n        # 'decoder.layers.0.mlp.experts.linear_fc1.weight2',\n        # 'decoder.layers.0.mlp.experts.linear_fc1.weight3',\n        # moe2\n        # 'decoder.layers.0.mlp.experts.linear_fc2.weight0',\n        # 'decoder.layers.0.mlp.experts.linear_fc2.weight1',\n        layer_number = name.split(\".\")[2]\n        convert_names = []\n        if \"pre_mlp_layernorm\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.post_attention_layernorm.weight\")\n            assert len(params) == 1\n        elif \"mlp.router.weight\" in name:\n            convert_names.append(f\"model.layers.{layer_number}.mlp.gate.weight\")\n            assert len(params) == 1\n        elif \"mlp.experts.linear_fc1\" in name:  # split gate_proj and up_proj\n            expert_id = name.split(\"weight\")[-1]\n            convert_names.append(f\"model.layers.{layer_number}.mlp.experts.{expert_id}.gate_proj.weight\")\n            convert_names.append(f\"model.layers.{layer_number}.mlp.experts.{expert_id}.up_proj.weight\")\n            assert len(params) == 2\n        elif \"mlp.experts.linear_fc2\" in name:\n            expert_id = name.split(\"weight\")[-1]\n            convert_names.append(f\"model.layers.{layer_number}.mlp.experts.{expert_id}.down_proj.weight\")\n            assert len(params) == 1\n        else:\n            raise NotImplementedError(f\"Unsupported parameter name: {name}\")\n        return convert_names, params\n"
  },
  {
    "path": "siirl/models/model_utils/__init__.py",
    "content": ""
  },
  {
    "path": "siirl/models/model_utils/visual.py",
    "content": "from dataclasses import dataclass\nfrom typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple\n\nimport torch\nimport transformers\nimport transformers.models\nfrom transformers.activations import ACT2FN\n\nfrom loguru import logger\n\nif TYPE_CHECKING:\n    from transformers import (\n        LlavaConfig,\n        PretrainedConfig,\n        PreTrainedModel,\n        ProcessorMixin,\n    )\n\n    from siirl.params import ModelArguments\n\ntransformers_logger = transformers.utils.logging.get_logger(__name__)\n\n\n@dataclass\nclass CompositeModel:\n    model_type: str\n    projector_key: str\n    vision_model_keys: List[str]\n    language_model_keys: List[str]\n    lora_conflict_keys: List[str]\n\n    def get_projector(self, module: \"torch.nn.Module\") -> \"torch.nn.Module\":\n        for key in self.projector_key.split(\".\"):\n            module = getattr(module, key)\n\n        return module\n\n\nCOMPOSITE_MODELS: Dict[str, \"CompositeModel\"] = {}\n\n\ndef _register_composite_model(\n    model_type: str,\n    projector_key: Optional[str] = None,\n    vision_model_keys: Optional[List[str]] = None,\n    language_model_keys: Optional[List[str]] = None,\n    lora_conflict_keys: Optional[List[str]] = None,\n):\n    COMPOSITE_MODELS[model_type] = CompositeModel(\n        model_type=model_type,\n        projector_key=projector_key or \"multi_modal_projector\",\n        vision_model_keys=vision_model_keys or [\"vision_tower\"],\n        language_model_keys=language_model_keys or [\"language_model\"],\n        lora_conflict_keys=lora_conflict_keys or [],\n    )\n\n\nclass LlavaMultiModalProjectorForYiVL(torch.nn.Module):\n    def __init__(self, config: \"LlavaConfig\") -> None:\n        super().__init__()\n\n        self.config = config\n        if config is None:\n            return\n\n        self.linear_1 = torch.nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)\n        self.linear_2 = torch.nn.LayerNorm(config.text_config.hidden_size, bias=True)\n        self.linear_3 = torch.nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)\n        self.linear_4 = torch.nn.LayerNorm(config.text_config.hidden_size, bias=True)\n        self.act = ACT2FN[config.projector_hidden_act]\n\n    def forward(self, image_features: \"torch.Tensor\") -> \"torch.Tensor\":\n        hidden_states = self.linear_1(image_features)\n        hidden_states = self.linear_2(hidden_states)\n        hidden_states = self.act(hidden_states)\n        hidden_states = self.linear_3(hidden_states)\n        hidden_states = self.linear_4(hidden_states)\n        if hidden_states.dtype == torch.float32:\n            if torch.is_autocast_enabled():\n                target_dtype = torch.get_autocast_gpu_dtype()\n            elif hasattr(self.config, \"_pre_quantization_dtype\"):\n                target_dtype = self.config._pre_quantization_dtype\n            else:\n                target_dtype = self.linear_1.weight.dtype\n\n            transformers_logger.warning(\"The hidden states seems to be silently casted in float32.\")\n            hidden_states = hidden_states.to(target_dtype)\n\n        return hidden_states\n\n\nclass LlavaMultiModalProjectorForYiVLForVLLM(LlavaMultiModalProjectorForYiVL):\n    def __init__(self, vision_hidden_size: int, text_hidden_size: int, projector_hidden_act: str) -> None:\n        super().__init__(config=None)\n\n        self.linear_1 = torch.nn.Linear(vision_hidden_size, text_hidden_size, bias=True)\n     
   self.linear_2 = torch.nn.LayerNorm(text_hidden_size, bias=True)\n        self.linear_3 = torch.nn.Linear(text_hidden_size, text_hidden_size, bias=True)\n        self.linear_4 = torch.nn.LayerNorm(text_hidden_size, bias=True)\n        self.act = ACT2FN[projector_hidden_act]\n\n\ndef autocast_projector_dtype(model: \"PreTrainedModel\", model_args: \"ModelArguments\") -> None:\n    r\"\"\"\n    Casts projector output to half precision for fine-tuning quantized VLMs.\n    \"\"\"\n\n    def _mm_projector_forward_post_hook(module: \"torch.nn.Module\", args: Tuple[\"torch.Tensor\"], output: \"torch.Tensor\") -> \"torch.Tensor\":\n        return output.to(model_args.compute_dtype)\n\n    if getattr(model, \"quantization_method\", None):\n        model_type = getattr(model.config, \"model_type\", None)\n        if model_type in COMPOSITE_MODELS:\n            mm_projector = COMPOSITE_MODELS[model_type].get_projector(model)\n        else:\n            return\n\n        logger.info(f\"Casting multimodal projector outputs in {model_args.compute_dtype}.\")\n        mm_projector.register_forward_hook(_mm_projector_forward_post_hook)\n\n\ndef configure_visual_model(config: \"PretrainedConfig\") -> None:\n    r\"\"\"\n    Patches VLMs before loading them.\n    \"\"\"\n    if getattr(config, \"text_config\", None) and not getattr(config, \"hidden_size\", None):\n        # required for ds zero3 and valuehead models\n        setattr(config, \"hidden_size\", getattr(config.text_config, \"hidden_size\", None))\n\n    if getattr(config, \"is_yi_vl_derived_model\", None):\n        logger.info(\"Detected Yi-VL model, applying projector patch.\")\n        transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL\n\n\ndef get_forbidden_modules(config: \"PretrainedConfig\", finetuning_args: \"FinetuningArguments\") -> Set[str]:\n    r\"\"\"\n    Freezes vision tower and language model for VLM full/freeze tuning.\n    \"\"\"\n    model_type = getattr(config, \"model_type\", None)\n    forbidden_modules = set()\n    if model_type in COMPOSITE_MODELS:\n        if finetuning_args.freeze_vision_tower:\n            vision_model_keys = COMPOSITE_MODELS[model_type].vision_model_keys\n            logger.info(f\"Set vision model not trainable: {vision_model_keys}.\")\n            forbidden_modules.update(vision_model_keys)\n\n        if finetuning_args.freeze_multi_modal_projector:\n            projector_key = COMPOSITE_MODELS[model_type].projector_key\n            logger.info(f\"Set multi model projector not trainable: {projector_key}.\")\n            forbidden_modules.add(projector_key)\n\n        if finetuning_args.train_mm_proj_only:\n            language_model_keys = COMPOSITE_MODELS[model_type].language_model_keys\n            logger.info(f\"Set language model not trainable: {language_model_keys}.\")\n            forbidden_modules.update(language_model_keys)\n\n    return forbidden_modules\n\n\ndef get_image_seqlen(config: \"PretrainedConfig\") -> int:\n    r\"\"\"\n    Computes the number of special tokens per image.\n    \"\"\"\n    model_type = getattr(config, \"model_type\", None)\n    if model_type == \"llava\":\n        image_seqlen = (config.vision_config.image_size // config.vision_config.patch_size) ** 2\n        if getattr(config, \"vision_feature_select_strategy\", \"default\") == \"full\":  # add [CLS] token\n            image_seqlen += 1\n    elif model_type == \"paligemma\":\n        image_seqlen = config.vision_config.num_image_tokens\n    else:\n        
image_seqlen = -1\n\n    return image_seqlen\n\n\ndef get_patch_size(config: \"PretrainedConfig\", processor: \"ProcessorMixin\") -> int:\n    r\"\"\"\n    Computes the patch size of the ViT.\n    \"\"\"\n    patch_size = getattr(config.vision_config, \"patch_size\", getattr(processor, \"patch_size\", -1))\n    return patch_size\n\n\ndef get_vision_feature_select_strategy(config: \"PretrainedConfig\", processor: \"ProcessorMixin\") -> str:\n    r\"\"\"\n    Get the vision_feature_select_strategy.\n    \"\"\"\n    vision_feature_select_strategy = getattr(\n        config,\n        \"vision_feature_select_strategy\",\n        getattr(processor, \"vision_feature_select_strategy\", \"default\"),\n    )\n    return vision_feature_select_strategy\n\n\ndef patch_target_modules(\n    model: \"PreTrainedModel\",\n    finetuning_args: \"FinetuningArguments\",\n    target_modules: Sequence[str],\n) -> List[str]:\n    r\"\"\"\n    Freezes vision tower for VLM LoRA tuning.\n    \"\"\"\n    model_type = getattr(model.config, \"model_type\", None)\n    if model_type in COMPOSITE_MODELS:\n        forbidden_modules = get_forbidden_modules(model.config, finetuning_args)\n        forbidden_modules.update(COMPOSITE_MODELS[model_type].lora_conflict_keys)\n        module_names = []\n        for name, _ in model.named_modules():\n            if any(target_module in name for target_module in target_modules) and not any(forbidden_module in name for forbidden_module in forbidden_modules):\n                module_names.append(name)\n\n        return module_names\n    else:\n        return target_modules\n\n\n_register_composite_model(\n    model_type=\"llava\",\n)\n\n_register_composite_model(\n    model_type=\"llava_next\",\n)\n\n_register_composite_model(\n    model_type=\"llava_next_video\",\n)\n\n_register_composite_model(\n    model_type=\"minicpmv\",\n    projector_key=\"resampler\",\n    vision_model_keys=[\"vpm\"],\n    language_model_keys=[\"llm\"],\n)\n\n_register_composite_model(\n    model_type=\"minicpmo\",\n    projector_key=\"resampler\",\n    vision_model_keys=[\n        \"vpm\",\n        \"apm\",\n        \"audio_avg_pooler\",\n        \"audio_projection_layer\",\n        \"tts\",\n    ],\n    language_model_keys=[\"llm\"],\n    lora_conflict_keys=[\"audio_projection_layer\"],\n)\n\n_register_composite_model(\n    model_type=\"paligemma\",\n)\n\n_register_composite_model(\n    model_type=\"video_llava\",\n)\n\n_register_composite_model(\n    model_type=\"mllama\",\n    vision_model_keys=[\"vision_model\"],\n)\n\n_register_composite_model(\n    model_type=\"qwen2_audio\",\n    vision_model_keys=[\"audio_tower\"],\n)\n\n_register_composite_model(\n    model_type=\"qwen2_vl\",\n    projector_key=\"visual.merger\",\n    vision_model_keys=[\"visual.patch_embed\", \"visual.blocks\"],\n    language_model_keys=[\"model\", \"lm_head\"],\n    lora_conflict_keys=[\"patch_embed\"],\n)\n\n_register_composite_model(\n    model_type=\"qwen2_5_vl\",\n    projector_key=\"visual.merger\",\n    vision_model_keys=[\"visual.patch_embed\", \"visual.blocks\"],\n    language_model_keys=[\"model\", \"lm_head\"],\n    lora_conflict_keys=[\"patch_embed\"],\n)\n"
  },
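For reference, the llava branch of `get_image_seqlen` above reduces to a one-line formula. Below is a minimal standalone sketch of that branch; `image_size=336` / `patch_size=14` are assumed example values (the common CLIP ViT-L/14-336 setup), not read from any real checkpoint.

```python
# Standalone sketch of the llava branch of get_image_seqlen() above.
# image_size=336 / patch_size=14 are assumed example values (CLIP ViT-L/14-336).
from types import SimpleNamespace


def llava_image_seqlen(vision_config, vision_feature_select_strategy="default"):
    seqlen = (vision_config.image_size // vision_config.patch_size) ** 2
    if vision_feature_select_strategy == "full":  # keep the [CLS] token as well
        seqlen += 1
    return seqlen


if __name__ == "__main__":
    cfg = SimpleNamespace(image_size=336, patch_size=14)
    assert llava_image_seqlen(cfg) == 576          # 24 x 24 patches
    assert llava_image_seqlen(cfg, "full") == 577  # patches + [CLS]
```

With the "full" select strategy the [CLS] token is kept, which is why the count becomes 577 rather than 576 in this example.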
  {
    "path": "siirl/models/patcher.py",
    "content": "# Copyright 2025 the LlamaFactory team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom types import MethodType\nfrom typing import TYPE_CHECKING, Any, Dict\n\nimport torch\nfrom peft import PeftModel\nfrom transformers import (\n    PreTrainedModel,\n    PreTrainedTokenizerBase,\n    is_torch_npu_available,\n)\n\nfrom loguru import logger\nfrom siirl.models.model_utils.visual import (\n    get_image_seqlen,\n    get_patch_size,\n    get_vision_feature_select_strategy,\n)\n\nif TYPE_CHECKING:\n    from transformers import PretrainedConfig, PreTrainedTokenizer, ProcessorMixin\n\n    from siirl.params import ModelArguments\n\n\ndef patch_tokenizer(tokenizer: \"PreTrainedTokenizer\", model_args: \"ModelArguments\", config: \"PretrainedConfig\") -> None:\n    if \"PreTrainedTokenizerBase\" not in str(tokenizer._pad.__func__):\n        tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)\n\n    if model_args.model_max_length is not None and tokenizer.model_max_length != model_args.model_max_length:\n        tokenizer.model_max_length = model_args.model_max_length\n\n    if model_args.new_special_tokens is not None:\n        num_added_tokens = tokenizer.add_special_tokens(\n            dict(additional_special_tokens=model_args.new_special_tokens),\n            replace_additional_special_tokens=False,\n        )\n        logger.info(\"Add {} to special tokens.\".format(\",\".join(model_args.new_special_tokens)))\n        if num_added_tokens > 0 and not model_args.resize_vocab:\n            model_args.resize_vocab = True\n            logger.warning(\"New tokens have been added, changed `resize_vocab` to True.\")\n\n    if \"InternVL\" in config.architectures[0]:\n\n        def eos_token_id_patch(self):\n            return self.super().eos_token_id\n\n        tokenizer.__class__.eos_token_id = property(eos_token_id_patch)\n        tokenizer.eos_token = \"<|im_end|>\"\n        tokenizer.eos_token_id = tokenizer.convert_tokens_to_ids(\"<|im_end|>\")\n\n\ndef patch_processor(\n    processor: \"ProcessorMixin\",\n    config: \"PretrainedConfig\",\n    tokenizer: \"PreTrainedTokenizer\",\n    model_args: \"ModelArguments\",\n) -> None:\n    setattr(processor, \"tokenizer\", tokenizer)\n    if getattr(config, \"vision_config\", None) is not None:  # visual models\n        setattr(processor, \"image_seqlen\", get_image_seqlen(config))\n        setattr(processor, \"patch_size\", get_patch_size(config, processor))\n        setattr(processor, \"image_max_pixels\", model_args.image_max_pixels)\n        setattr(processor, \"image_min_pixels\", model_args.image_min_pixels)\n        setattr(processor, \"video_max_pixels\", model_args.video_max_pixels)\n        setattr(processor, \"video_min_pixels\", model_args.video_min_pixels)\n        setattr(processor, \"video_fps\", model_args.video_fps)\n        setattr(processor, \"video_maxlen\", model_args.video_maxlen)\n        setattr(\n            processor,\n            \"vision_feature_select_strategy\",\n            
get_vision_feature_select_strategy(config, processor),\n        )\n"
  },
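`patch_tokenizer` above re-binds the base-class `_pad` onto the tokenizer instance via `types.MethodType`. The snippet below is a self-contained illustration of that re-binding pattern only; `Base` and `Custom` are hypothetical stand-ins for `PreTrainedTokenizerBase` and a tokenizer subclass whose own `_pad` we want to bypass.

```python
# Minimal illustration of the MethodType re-binding used by patch_tokenizer().
# Base/Custom are hypothetical stand-ins, not transformers classes.
from types import MethodType


class Base:
    def _pad(self, items):
        return items + ["<pad>"]


class Custom(Base):
    def _pad(self, items):  # override we want to replace on the instance
        raise NotImplementedError


tok = Custom()
if "Base" not in str(tok._pad.__func__):   # same guard style as the patch above
    tok._pad = MethodType(Base._pad, tok)  # re-bind the base implementation
assert tok._pad(["hi"]) == ["hi", "<pad>"]
```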
  {
    "path": "siirl/models/qwen2/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "siirl/models/qwen2/megatron/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .modeling_qwen2_megatron import (\n    ParallelQwen2ForCausalLM,\n    # rmpad with megatron\n    ParallelQwen2ForCausalLMRmPad,\n    # rmpad with megatron and pipeline parallelism\n    ParallelQwen2ForCausalLMRmPadPP,\n    ParallelQwen2ForValueRmPad,\n    ParallelQwen2ForValueRmPadPP,\n    # original model with megatron\n    ParallelQwen2Model,\n)\n\n__all__ = [\n    \"ParallelQwen2ForCausalLM\",\n    \"ParallelQwen2ForCausalLMRmPad\",\n    \"ParallelQwen2ForCausalLMRmPadPP\",\n    \"ParallelQwen2ForValueRmPad\",\n    \"ParallelQwen2ForValueRmPadPP\",\n    \"ParallelQwen2Model\",\n]\n"
  },
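The exported class names encode which optimizations each variant uses (remove-padding, and remove-padding plus pipeline parallelism). A hedged sketch of choosing among these exports follows; the two boolean flags are illustrative only and are not siiRL config keys.

```python
# Hedged sketch: choosing among the exported Qwen2 variants by feature.
# The flag names are illustrative, not siiRL configuration options.
from siirl.models.qwen2.megatron import (
    ParallelQwen2ForCausalLM,
    ParallelQwen2ForCausalLMRmPad,
    ParallelQwen2ForCausalLMRmPadPP,
)


def pick_causal_lm_class(remove_padding: bool, pipeline_parallel: bool):
    if remove_padding and pipeline_parallel:
        return ParallelQwen2ForCausalLMRmPadPP  # rmpad + pipeline parallelism
    if remove_padding:
        return ParallelQwen2ForCausalLMRmPad    # rmpad, single pipeline stage
    return ParallelQwen2ForCausalLM             # original model with megatron
```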
  {
    "path": "siirl/models/qwen2/megatron/checkpoint_utils/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "siirl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport time\n\nimport torch\nimport torch.distributed as dist\n\nfrom siirl.utils.extras.device import get_device_id, get_torch_device\n\ndef _megatron_calc_layer_map(config):\n    \"\"\"Calculate the mapping of global layer_idx to local layer_idx\n    Returns:\n        layer_map (Dict: int -> tuple(int, int, int)):\n            mapping from the global layer index to\n            a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)\n    \"\"\"\n    from megatron.core import mpu\n\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n\n    layer_map = dict()\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers\n\n    for pp_rank_idx in range(pp_size):\n        for virtual_pp_rank_idx in range(virtual_pp_size):\n            layer_offset = virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model\n            for layer_idx in range(num_layers_per_model):\n                layer_map[layer_offset + layer_idx] = (\n                    pp_rank_idx,\n                    virtual_pp_rank_idx,\n                    layer_idx,\n                )\n    return layer_map\n\n\ndef load_state_dict_to_megatron_qwen2(state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False):\n    \"\"\"Load merged state_dict to sharded Megatron module in training.\"\"\"\n    from megatron.core import DistributedDataParallel as LocalDDP\n    from megatron.core import mpu\n    from megatron.core.transformer.module import Float16Module\n    from torch.nn.parallel import DistributedDataParallel as torchDDP\n\n    from siirl.utils.megatron.megatron_utils import print_rank_0, unwrap_model\n\n    start_time = time.time()\n\n    def _get_gpt_model(model):\n        return model\n\n    def fetch_params(module):\n        for param in module.parameters():\n            torch.distributed.fetch(param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group())\n\n    dp_rank = mpu.get_data_parallel_rank()\n    pp_rank = mpu.get_pipeline_model_parallel_rank()\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n    mp_group = mpu.get_model_parallel_group()\n\n    if torch.distributed.get_rank() == 0:\n        assert mp_group.rank() == 0, f\"mp_rank:[{mp_group.rank}] != 0 on rank #0\"\n        assert pp_rank == 0, f\"pp_rank:[{pp_rank}] != 0 on rank #0\"\n        assert dp_rank == 0, f\"dp_rank:[{dp_rank}] != 0 on rank #0\"\n\n    if not isinstance(wrapped_models, (list, tuple)):\n        wrapped_models = list(wrapped_models)\n\n    assert len(wrapped_models) == virtual_pp_size\n    num_layers_per_model = 
config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, f\"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size: {virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}\"\n\n    models = [None] * len(wrapped_models)\n\n    for i, wrapped_model in enumerate(wrapped_models):\n        models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))\n        gpt_model_module = _get_gpt_model(models[i])\n        assert len(gpt_model_module.model.layers) == num_layers_per_model\n\n    def _fetch_tensor(tensor, name) -> torch.Tensor:\n        \"\"\"fetch tensor\"\"\"\n        nonlocal state_dict\n        if tensor is not None:\n            tensor = tensor.data.copy_(state_dict[name], non_blocking=True)\n\n    def _fetch_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:\n        \"\"\"fetch tensor in tp shards\"\"\"\n        nonlocal state_dict\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        if name in state_dict:\n            full_weight = state_dict[name]\n\n            if mutate_func is not None:\n                full_weight = mutate_func(full_weight)\n            tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)\n            if tensor is not None:\n                tensor = tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True)\n        else:\n            print(f\"tp_shard tensor:[{name}] not in state_dict, skip loading\")\n\n    def _fetch_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:\n        \"\"\"fetch tensor in tp shards\"\"\"\n        nonlocal state_dict\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        if name in state_dict:\n            full_weight = state_dict[name]\n\n            if mutate_func is not None:\n                full_weight = mutate_func(full_weight)\n            tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)\n            if tensor is not None:\n                tensor = tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True)\n        else:\n            print(f\"tp_shard tensor:[{name}] not in state_dict, skip loading\")\n\n    def _fetch_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor:\n        \"\"\"fetch gate_up tensor in tp shards\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        if gate_name in state_dict and up_name in state_dict:\n            gate_weight = state_dict[gate_name]\n            up_weight = state_dict[up_name]\n            new_gate_up_weight = torch.empty(config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id())\n            for i in range(tp_size):\n                intermediate_size_tp = config.intermediate_size // tp_size\n                gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]\n                up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]\n                new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_(torch.cat([gate_weight_tp, up_weight_tp], dim=0))\n\n            tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0)\n 
           if tensor is not None:\n                tensor = tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True)\n        else:\n            print(f\"tp_shard tensor:[{gate_name}, {up_name}] not in state_dict, skip loading\")\n\n    def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> torch.Tensor:\n        \"\"\"fetch tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        assert q_name in state_dict and k_name in state_dict and v_name in state_dict\n        full_weight_q = state_dict[q_name]\n        full_weight_k = state_dict[k_name]\n        full_weight_v = state_dict[v_name]\n\n        hidden_size_per_head = config.hidden_size // config.num_attention_heads\n\n        if config.num_key_value_heads >= tp_size:\n            q_size_tp = config.hidden_size // tp_size\n            kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size\n            total_size = q_size_tp + 2 * kv_size_tp\n            if not bias:\n                new_weight_qkv = torch.empty(total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id())\n            else:\n                new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=get_device_id())\n            for i in range(tp_size):\n                q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]\n                k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp]\n                v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp]\n                new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0))\n\n        else:\n            q_size_tp = config.hidden_size // tp_size\n            kv_size_tp = hidden_size_per_head\n            total_size = q_size_tp + 2 * kv_size_tp\n            if not bias:\n                new_weight_qkv = torch.empty(total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id())\n            else:\n                new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=get_device_id())\n            for i in range(tp_size):\n                q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]\n                start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head\n                end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head\n                k_part = full_weight_k[start_idx:end_idx]\n                v_part = full_weight_v[start_idx:end_idx]\n                new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0))\n\n        tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0)\n        if tensor is not None:\n            tensor = tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True)\n\n    # Embeddings\n    # -------------------\n    print_rank_0(\"loading embeddings...\")\n    gpt_model_module = _get_gpt_model(models[0])\n    if pp_rank == 0:\n        embed_tokens_weight = gpt_model_module.model.embed_tokens.weight\n        _fetch_tp_shard_tensor_vocab(embed_tokens_weight, \"model.embed_tokens.weight\")\n\n    # Transformer layers\n    # -------------------\n    layer_map = _megatron_calc_layer_map(config)\n\n    pp_rank = mpu.get_pipeline_model_parallel_rank()\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    
num_layer_per_pp = config.num_hidden_layers // pp_size\n    vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()\n\n    layer_list = []\n    if vpp_size is not None:\n        for vpp_rank in range(vpp_size):\n            num_layer_vpp_chunk = num_layer_per_pp // vpp_size\n            num_layer_this_model = num_layer_vpp_chunk\n            offset = vpp_rank * (config.num_hidden_layers // mpu.get_virtual_pipeline_model_parallel_world_size()) + (mpu.get_pipeline_model_parallel_rank() * num_layer_vpp_chunk)\n            layer_list.extend(list(range(offset, offset + num_layer_this_model)))\n    else:\n        num_layer_this_model = num_layer_per_pp\n        offset = pp_rank * num_layer_per_pp\n        layer_list.extend(list(range(offset, offset + num_layer_this_model)))\n\n    for layer in layer_list:\n        print(f\"{torch.distributed.get_rank()} loading layer #{layer}...\")\n        layer_name = f\"model.layers.{layer}\"\n        dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer]\n\n        print(f\"{torch.distributed.get_rank()} offset: {offset}, num_layer_this_model: {num_layer_this_model}, layer_name: {layer_name}, layer_map[layer]: {layer_map[layer]}\")\n\n        gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank])\n        sync_layer = gpt_model_module.model.layers[dst_layer_idx]\n\n        _fetch_tensor(\n            sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None,\n            f\"{layer_name}.input_layernorm.weight\",\n        )\n\n        _fetch_tp_shard_tensor_qkv(\n            sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None,\n            f\"{layer_name}.self_attn.q_proj.weight\",\n            f\"{layer_name}.self_attn.k_proj.weight\",\n            f\"{layer_name}.self_attn.v_proj.weight\",\n        )\n\n        _fetch_tp_shard_tensor_qkv(\n            sync_layer.self_attn.qkv_proj.bias if dst_pp_rank == pp_rank else None,\n            f\"{layer_name}.self_attn.q_proj.bias\",\n            f\"{layer_name}.self_attn.k_proj.bias\",\n            f\"{layer_name}.self_attn.v_proj.bias\",\n            bias=True,\n        )\n\n        _fetch_tp_shard_tensor(\n            sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None,\n            f\"{layer_name}.self_attn.o_proj.weight\",\n            chunk_dim=1,\n        )\n\n        _fetch_tensor(\n            sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None,\n            f\"{layer_name}.post_attention_layernorm.weight\",\n        )\n\n        _fetch_tp_shard_tensor_gate_up(\n            sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None,\n            f\"{layer_name}.mlp.gate_proj.weight\",\n            f\"{layer_name}.mlp.up_proj.weight\",\n        )\n\n        _fetch_tp_shard_tensor(\n            sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None,\n            f\"{layer_name}.mlp.down_proj.weight\",\n            chunk_dim=1,\n        )\n    # Final Layernorm\n    # -------------------\n    print_rank_0(\"loading final layernorm...\")\n    gpt_model_module = _get_gpt_model(models[-1])\n    _fetch_tensor(\n        getattr(gpt_model_module.model.norm, \"weight\", None),\n        \"model.norm.weight\",\n    )\n\n    if tie_word_embeddings:\n        print_rank_0(\"tie_word_embeddings skip load lm_head\")\n    else:\n        print_rank_0(\"loading lm_head...\")\n        if pp_rank + 1 == pp_size:\n            lm_head_weight = gpt_model_module.lm_head.weight\n\n            if 
is_value_model:\n                if \"lm_head.weight\" in state_dict and state_dict[\"lm_head.weight\"].shape[0] == 1:\n                    _fetch_tensor(lm_head_weight, \"lm_head.weight\")\n                    print_rank_0(\"load lm_head from value_head weight\")\n                elif \"reward_head.weight\" in state_dict and state_dict[\"reward_head.weight\"].shape[0] == 1:\n                    _fetch_tensor(lm_head_weight, \"reward_head.weight\")\n                    print_rank_0(\"load lm_head from value_head weight\")\n                else:\n                    _fetch_tensor(None, \"lm_head.weight\")\n                    print_rank_0(\"fail to match lm_head in value_model\")\n\n            else:\n                _fetch_tp_shard_tensor(lm_head_weight, \"lm_head.weight\")\n\n    dist.barrier()\n    get_torch_device().empty_cache()\n    print_rank_0(f\"loading megatron ckpt done, time elapsed {time.time() - start_time}s\")\n"
  },
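The layer placement the loader relies on comes from `_megatron_calc_layer_map`. Below is a standalone re-implementation with the Megatron `mpu` queries replaced by plain arguments, so the mapping can be sanity-checked without a cluster; the sizes in the check are assumed toy values.

```python
# Standalone re-implementation of _megatron_calc_layer_map() above,
# with mpu queries replaced by plain arguments. Sizes are assumed toy values.
def calc_layer_map(num_hidden_layers: int, pp_size: int, virtual_pp_size: int = 1):
    layers_per_model = num_hidden_layers // pp_size // virtual_pp_size
    assert layers_per_model * pp_size * virtual_pp_size == num_hidden_layers
    layer_map = {}
    for pp_rank in range(pp_size):
        for vpp_rank in range(virtual_pp_size):
            offset = vpp_rank * (num_hidden_layers // virtual_pp_size) + pp_rank * layers_per_model
            for local_idx in range(layers_per_model):
                layer_map[offset + local_idx] = (pp_rank, vpp_rank, local_idx)
    return layer_map


if __name__ == "__main__":
    m = calc_layer_map(num_hidden_layers=8, pp_size=2, virtual_pp_size=2)
    assert m[5] == (0, 1, 1)  # global layer 5 -> pp rank 0, vpp chunk 1, local slot 1
    assert m[2] == (1, 0, 0)  # global layer 2 -> first local layer on pp rank 1
```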
  {
    "path": "siirl/models/qwen2/megatron/checkpoint_utils/qwen2_loader_depracated.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport time\n\nimport torch\nimport torch.distributed as dist\n\nfrom siirl.utils.extras.device import get_device_id, get_torch_device\n\n\ndef _megatron_calc_layer_map(config):\n    \"\"\"Calculate the mapping of global layer_idx to local layer_idx\n    Returns:\n        layer_map (Dict: int -> tuple(int, int, int)):\n            mapping from the global layer index to\n            a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)\n    \"\"\"\n    from megatron.core import mpu\n\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n\n    layer_map = dict()\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers\n\n    for pp_rank_idx in range(pp_size):\n        for virtual_pp_rank_idx in range(virtual_pp_size):\n            layer_offset = virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model\n            for layer_idx in range(num_layers_per_model):\n                layer_map[layer_offset + layer_idx] = (\n                    pp_rank_idx,\n                    virtual_pp_rank_idx,\n                    layer_idx,\n                )\n    return layer_map\n\n\ndef load_state_dict_to_megatron_qwen2(state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False):\n    \"\"\"Load merged state_dict to sharded Megatron module in training.\"\"\"\n    from megatron.core import DistributedDataParallel as LocalDDP\n    from megatron.core import mpu\n    from megatron.core.transformer.module import Float16Module\n    from torch.nn.parallel import DistributedDataParallel as torchDDP\n\n    from siirl.utils.megatron.megatron_utils import print_rank_0, unwrap_model\n\n    start_time = time.time()\n\n    def _get_gpt_model(model):\n        return model\n\n    def broadcast_params(module):\n        for param in module.parameters():\n            torch.distributed.broadcast(param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group())\n\n    dp_rank = mpu.get_data_parallel_rank()\n    pp_rank = mpu.get_pipeline_model_parallel_rank()\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n    mp_group = mpu.get_model_parallel_group()\n\n    if torch.distributed.get_rank() == 0:\n        assert mp_group.rank() == 0, f\"mp_rank:[{mp_group.rank}] != 0 on rank #0\"\n        assert pp_rank == 0, f\"pp_rank:[{pp_rank}] != 0 on rank #0\"\n        assert dp_rank == 0, f\"dp_rank:[{dp_rank}] != 0 on rank #0\"\n\n    if not isinstance(wrapped_models, (list, tuple)):\n        wrapped_models = list(wrapped_models)\n\n    assert len(wrapped_models) == virtual_pp_size\n    num_layers_per_model = 
config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, f\"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size: {virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}\"\n\n    models = [None] * len(wrapped_models)\n\n    for i, wrapped_model in enumerate(wrapped_models):\n        models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))\n        gpt_model_module = _get_gpt_model(models[i])\n        assert len(gpt_model_module.model.layers) == num_layers_per_model\n\n    def _broadcast_tensor(tensor, name) -> torch.Tensor:\n        \"\"\"broadcast tensor from rank0 across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        if torch.distributed.get_rank() == 0:\n            if name in state_dict:\n                weight = state_dict[name]\n                tensor_shape = weight.shape\n            else:\n                tensor_shape = None\n        else:\n            weight = None\n            tensor_shape = None\n\n        obj_list = [tensor_shape]\n        dist.broadcast_object_list(obj_list, src=0, group=mp_group)\n        tensor_shape = obj_list[0]\n\n        if tensor_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tensor:[{name}] not in state_dict, skip load\")\n            return\n\n        if tensor is None:\n            tensor = torch.empty(\n                tensor_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        if torch.distributed.get_rank() == 0:\n            tensor.data.copy_(weight)\n        dist.broadcast(tensor, src=0, group=mp_group)\n\n    def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        if torch.distributed.get_rank() == 0:\n            if name in state_dict:\n                full_weight = state_dict[name]\n\n                if mutate_func is not None:\n                    full_weight = mutate_func(full_weight)\n                tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)\n                chunk_shape = tensor_chunk[0].shape\n            else:\n                chunk_shape = None\n        else:\n            chunk_shape = None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=0, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{name}] not in state_dict, skip loading\")\n            return\n\n        if tensor is None:\n            sync_tensor = torch.empty(\n                chunk_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        else:\n            assert tensor.shape == chunk_shape, f\"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}\"\n            sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False)\n\n        for i in range(tp_size):\n            if 
torch.distributed.get_rank() == 0:\n                sync_tensor.data.copy_(tensor_chunk[i])\n            dist.broadcast(sync_tensor, src=0, group=mp_group)\n            if (i == tp_rank) and (tensor is not None):\n                tensor.data.copy_(sync_tensor)\n\n    def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        if torch.distributed.get_rank() == 0:\n            if name in state_dict:\n                full_weight = state_dict[name]\n                if mutate_func is not None:\n                    full_weight = mutate_func(full_weight)\n                tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)\n                chunk_shape = tensor_chunk[0].shape\n            else:\n                chunk_shape = None\n        else:\n            chunk_shape = None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=0, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{name}] not in state_dict, skip loading\")\n            return\n\n        if tensor is None:\n            sync_tensor = torch.empty(\n                chunk_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        else:\n            assert tensor.shape == chunk_shape, f\"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}\"\n            sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False)\n\n        for i in range(tp_size):\n            if torch.distributed.get_rank() == 0:\n                sync_tensor.data.copy_(tensor_chunk[i])\n            dist.broadcast(sync_tensor, src=0, group=mp_group)\n            if (i == tp_rank) and (tensor is not None):\n                tensor.data.copy_(sync_tensor)\n\n    def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        if torch.distributed.get_rank() == 0:\n            gate_weight = state_dict[gate_name]\n            up_weight = state_dict[up_name]\n            new_gate_up_weight = torch.empty(config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id())\n            for i in range(tp_size):\n                intermediate_size_tp = config.intermediate_size // tp_size\n                gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]\n                up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp]\n                new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_(torch.cat([gate_weight_tp, up_weight_tp], dim=0))\n\n            tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0)\n            chunk_shape = tensor_chunk[0].shape\n        else:\n            chunk_shape = None\n\n        obj_list = [chunk_shape]\n        
dist.broadcast_object_list(obj_list, src=0, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading\")\n            return\n\n        if tensor is None:\n            sync_tensor = torch.empty(\n                chunk_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        else:\n            assert tensor.shape == chunk_shape, f\"rank #{torch.distributed.get_rank() == 0:} tensor {gate_name, up_name} shape {tensor.shape} != {chunk_shape}\"\n            sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False)\n\n        for i in range(tp_size):\n            if torch.distributed.get_rank() == 0:\n                sync_tensor.data.copy_(tensor_chunk[i])\n            dist.broadcast(sync_tensor, src=0, group=mp_group)\n            if (i == tp_rank) and (tensor is not None):\n                tensor.data.copy_(sync_tensor)\n\n    def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_rank = mpu.get_tensor_model_parallel_rank()\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        if torch.distributed.get_rank() == 0:\n            assert q_name in state_dict and k_name in state_dict and v_name in state_dict\n            full_weight_q = state_dict[q_name]\n            full_weight_k = state_dict[k_name]\n            full_weight_v = state_dict[v_name]\n\n            hidden_size_per_head = config.hidden_size // config.num_attention_heads\n\n            if config.num_key_value_heads >= tp_size:\n                q_size_tp = config.hidden_size // tp_size\n                kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size\n                total_size = q_size_tp + 2 * kv_size_tp\n                if not bias:\n                    new_weight_qkv = torch.empty(total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id())\n                else:\n                    new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=get_device_id())\n                for i in range(tp_size):\n                    q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]\n                    k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp]\n                    v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp]\n                    new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0))\n\n            else:\n                q_size_tp = config.hidden_size // tp_size\n                kv_size_tp = hidden_size_per_head\n                total_size = q_size_tp + 2 * kv_size_tp\n                if not bias:\n                    new_weight_qkv = torch.empty(total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id())\n                else:\n                    new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=get_device_id())\n                for i in range(tp_size):\n                    q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp]\n                    start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head\n     
               end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head\n                    k_part = full_weight_k[start_idx:end_idx]\n                    v_part = full_weight_v[start_idx:end_idx]\n                    new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0))\n\n            tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0)\n            chunk_shape = tensor_chunk[0].shape\n        else:\n            chunk_shape = None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=0, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{q_name, k_name, v_name}] not in state_dict, skip loading\")\n            return\n\n        if tensor is None:\n            sync_tensor = torch.empty(\n                chunk_shape,\n                dtype=params_dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n        else:\n            assert tensor.shape == chunk_shape, f\"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}\"\n            sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False)\n\n        for i in range(tp_size):\n            if torch.distributed.get_rank() == 0:\n                sync_tensor.data.copy_(tensor_chunk[i])\n            dist.broadcast(sync_tensor, src=0, group=mp_group)\n            if (i == tp_rank) and (tensor is not None):\n                tensor.data.copy_(sync_tensor)\n\n    if dp_rank == 0:\n        # Embeddings\n        # -------------------\n        print_rank_0(\"loading embeddings...\")\n        gpt_model_module = _get_gpt_model(models[0])\n        embed_tokens_weight = None\n        if pp_rank == 0:\n            embed_tokens_weight = gpt_model_module.model.embed_tokens.weight\n        _broadcast_tp_shard_tensor_vocab(embed_tokens_weight, \"model.embed_tokens.weight\")\n\n        # Transformer layers\n        # -------------------\n        layer_map = _megatron_calc_layer_map(config)\n\n        for layer in range(config.num_hidden_layers):\n            print_rank_0(f\"loading layer #{layer}...\")\n            layer_name = f\"model.layers.{layer}\"\n            dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer]\n\n            gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank])\n            sync_layer = gpt_model_module.model.layers[dst_layer_idx]\n\n            _broadcast_tensor(\n                sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.input_layernorm.weight\",\n            )\n\n            _broadcast_tp_shard_tensor_qkv(\n                sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.self_attn.q_proj.weight\",\n                f\"{layer_name}.self_attn.k_proj.weight\",\n                f\"{layer_name}.self_attn.v_proj.weight\",\n            )\n\n            _broadcast_tp_shard_tensor_qkv(\n                sync_layer.self_attn.qkv_proj.bias if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.self_attn.q_proj.bias\",\n                f\"{layer_name}.self_attn.k_proj.bias\",\n                f\"{layer_name}.self_attn.v_proj.bias\",\n                bias=True,\n            )\n\n            _broadcast_tp_shard_tensor(\n                
sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.self_attn.o_proj.weight\",\n                chunk_dim=1,\n            )\n\n            _broadcast_tensor(\n                sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.post_attention_layernorm.weight\",\n            )\n\n            _broadcast_tp_shard_tensor_gate_up(\n                sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.mlp.gate_proj.weight\",\n                f\"{layer_name}.mlp.up_proj.weight\",\n            )\n\n            _broadcast_tp_shard_tensor(\n                sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None,\n                f\"{layer_name}.mlp.down_proj.weight\",\n                chunk_dim=1,\n            )\n        # Final Layernorm\n        # -------------------\n        print_rank_0(\"loading final layernorm...\")\n        gpt_model_module = _get_gpt_model(models[-1])\n        _broadcast_tensor(\n            getattr(gpt_model_module.model.norm, \"weight\", None),\n            \"model.norm.weight\",\n        )\n\n        if tie_word_embeddings:\n            print_rank_0(\"tie_word_embeddings skip load lm_head\")\n        else:\n            print_rank_0(\"loading lm_head...\")\n            lm_head_weight = None\n            if pp_rank + 1 == pp_size:\n                lm_head_weight = gpt_model_module.lm_head.weight\n\n            if is_value_model:\n                if \"lm_head.weight\" in state_dict and state_dict[\"lm_head.weight\"].shape[0] == 1:\n                    _broadcast_tensor(lm_head_weight, \"lm_head.weight\")\n                    print_rank_0(\"load lm_head from value_head weight\")\n                elif \"reward_head.weight\" in state_dict and state_dict[\"reward_head.weight\"].shape[0] == 1:\n                    _broadcast_tensor(lm_head_weight, \"reward_head.weight\")\n                    print_rank_0(\"load lm_head from value_head weight\")\n                else:\n                    _broadcast_tensor(None, \"lm_head.weight\")\n                    print_rank_0(\"fail to match lm_head in value_model\")\n\n            else:\n                _broadcast_tp_shard_tensor(lm_head_weight, \"lm_head.weight\")\n\n    dist.barrier()\n    # Broadcast weights inside data parallel groups\n    for wrapped_model in wrapped_models:\n        broadcast_params(wrapped_model)\n\n    get_torch_device().empty_cache()\n    print_rank_0(f\"loading megatron ckpt done, time elapsed {time.time() - start_time}s\")\n"
  },
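Both loaders build the fused `gate_up_proj` weight by interleaving `gate_proj` and `up_proj` rows per TP rank, so that `torch.chunk` along dim 0 hands each rank exactly its own `[gate_shard; up_shard]` pair. The following single-process sketch mirrors that re-packing with tiny assumed tensor sizes.

```python
# Single-process sketch of the gate/up re-packing done by the *_gate_up helpers
# above: chunk i of the fused tensor is exactly [gate_shard_i; up_shard_i].
# Tensor sizes are tiny assumed values.
import torch


def fuse_gate_up(gate_w: torch.Tensor, up_w: torch.Tensor, tp_size: int) -> torch.Tensor:
    intermediate_size, hidden_size = gate_w.shape
    shard = intermediate_size // tp_size
    fused = torch.empty(intermediate_size * 2, hidden_size, dtype=gate_w.dtype)
    for i in range(tp_size):
        g = gate_w[i * shard : (i + 1) * shard]
        u = up_w[i * shard : (i + 1) * shard]
        fused[2 * shard * i : 2 * shard * (i + 1)] = torch.cat([g, u], dim=0)
    return fused


if __name__ == "__main__":
    gate = torch.arange(12, dtype=torch.float32).reshape(4, 3)  # intermediate=4, hidden=3
    up = gate + 100
    fused = fuse_gate_up(gate, up, tp_size=2)
    rank0_chunk = torch.chunk(fused, 2, dim=0)[0]               # what TP rank 0 keeps
    assert torch.equal(rank0_chunk, torch.cat([gate[:2], up[:2]], dim=0))
```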
  {
    "path": "siirl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport time\n\nimport torch\nimport torch.distributed as dist\nfrom megatron.core import mpu\nfrom megatron.core.distributed import DistributedDataParallel as LocalDDP\nfrom megatron.core.transformer.module import Float16Module\nfrom torch.nn.parallel import DistributedDataParallel as torchDDP\n\nfrom siirl.utils.extras.device import get_device_id, get_torch_device\nfrom siirl.utils.megatron.megatron_utils import print_rank_0, unwrap_model\n\n\ndef _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0):\n    \"\"\"given TP,DP,PP rank to get the global rank.\"\"\"\n\n    tp_size = mpu.get_tensor_model_parallel_world_size()\n    dp_size = mpu.get_data_parallel_world_size()\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    assert tp_size * dp_size * pp_size == torch.distributed.get_world_size(), f\"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}\"\n    # We only support TP-DP-PP grouping, for correctness when resharding\n    return (pp_rank * dp_size + dp_rank) * tp_size + tp_rank\n\n\ndef _megatron_calc_layer_map(config):\n    \"\"\"Calculate the mapping of global layer_idx to local layer_idx\n    Returns:\n        layer_map (Dict: int -> tuple(int, int, int)):\n            mapping from the global layer index to\n            a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)\n    \"\"\"\n    from megatron.core import mpu\n\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n\n    layer_map = dict()\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers\n\n    for pp_rank_idx in range(pp_size):\n        for virtual_pp_rank_idx in range(virtual_pp_size):\n            layer_offset = virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model\n            for layer_idx in range(num_layers_per_model):\n                layer_map[layer_offset + layer_idx] = (\n                    pp_rank_idx,\n                    virtual_pp_rank_idx,\n                    layer_idx,\n                )\n    return layer_map\n\n\ndef merge_megatron_ckpt_qwen2(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False):\n    \"\"\"Merge sharded parameters of a Megatron module into a merged checkpoint.\n\n    Args:\n        wrapped_models (list of megatron.core.distributed.DistributedDataParallel):\n            The local DDP wrapped megatron modules.\n        config (str or None):\n            HF config for model\n        dtype: model params type\n        is_value_model: if model is value model\n        tie_word_embeddings: tie_word_embeddings\n    Returns:\n        state_dict (dict):\n            The merged state_dict in rank 0, and an empty dictionary 
in other ranks.\n    \"\"\"\n    start_time = time.time()\n\n    def _get_gpt_model(model):\n        return model\n\n    dp_rank = mpu.get_data_parallel_rank()\n    pp_size = mpu.get_pipeline_model_parallel_world_size()\n    pp_rank = mpu.get_pipeline_model_parallel_rank()\n    virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1\n    mp_group = mpu.get_model_parallel_group()\n\n    if dist.get_rank() == 0:\n        assert mp_group.rank() == 0, f\"mp_rank:[{mp_group.rank}] != 0 on rank #0\"\n        assert pp_rank == 0, f\"pp_rank:[{pp_rank}] != 0 on rank #0\"\n        assert dp_rank == 0, f\"dp_rank:[{dp_rank}] != 0 on rank #0\"\n\n    if not isinstance(wrapped_models, (list, tuple)):\n        wrapped_models = list(wrapped_models)\n\n    assert len(wrapped_models) == virtual_pp_size\n    num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size\n    assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers\n\n    models = [None] * len(wrapped_models)\n\n    for i, wrapped_model in enumerate(wrapped_models):\n        models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))\n        assert len(models[i].model.layers) == num_layers_per_model, \"len model layers {} not equal to num_layers_per_model {}\".format(len(models[i].model.layers), num_layers_per_model)\n\n    state_dict = dict()\n\n    def _get_cpu_tensor(tensor: torch.Tensor):\n        if tensor is None:\n            return None\n        if tensor.device == torch.device(\"cpu\"):\n            return tensor.detach().clone()\n        return tensor.detach().cpu()\n\n    def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor:\n        \"\"\"broadcast tensor across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)\n\n        if torch.distributed.get_rank() == src_rank:\n            if tensor is None:\n                weight = None\n                tensor_shape = None\n            else:\n                weight = tensor\n                tensor_shape = weight.shape\n        else:\n            weight = None\n            tensor_shape = None\n\n        obj_list = [tensor_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        tensor_shape = obj_list[0]\n\n        if tensor_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tensor:[{name}] not exist, skip collect\")\n            return\n\n        if weight is None:\n            weight = torch.empty(\n                tensor_shape,\n                dtype=dtype,\n                device=get_device_id(),\n                requires_grad=False,\n            )\n\n        dist.broadcast(weight, src=src_rank, group=mp_group)\n\n        if torch.distributed.get_rank() == 0:\n            state_dict[name] = _get_cpu_tensor(weight)\n\n    def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)\n\n        chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, 
group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{name}] not exist, skip collecting\")\n            return\n\n        buffer_tensor = torch.empty(\n            chunk_shape,\n            dtype=dtype,\n            device=get_device_id(),\n            requires_grad=False,\n        )\n\n        chunk_tensors = [None] * tp_size\n\n        for i in range(tp_size):\n            cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)\n            sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor\n            dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)\n\n            if torch.distributed.get_rank() == 0:\n                chunk_tensors[i] = _get_cpu_tensor(sync_tensor)\n\n        if torch.distributed.get_rank() == 0:\n            full_tensor = torch.concat(chunk_tensors, dim=concat_dim)\n            if mutate_func is not None:\n                full_tensor = mutate_func(full_tensor)\n            state_dict[name] = full_tensor\n\n    def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) -> torch.Tensor:\n        \"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)\n\n        chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting\")\n            return\n\n        buffer_tensor = torch.empty(\n            chunk_shape,\n            dtype=dtype,\n            device=get_device_id(),\n            requires_grad=False,\n        )\n\n        chunk_tensors = [None] * tp_size\n\n        for i in range(tp_size):\n            cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)\n            sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor\n            dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)\n\n            if torch.distributed.get_rank() == 0:\n                chunk_tensors[i] = _get_cpu_tensor(sync_tensor)\n\n        if torch.distributed.get_rank() == 0:\n            full_tensor = torch.concat(chunk_tensors, dim=0)\n            intermediate_size_tp = config.intermediate_size // tp_size\n            gate_weight_list = []\n            up_weight_list = []\n            for i in range(tp_size):\n                gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)]\n                gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp]\n                up_weight_tp = gate_up_weight_tp[intermediate_size_tp:]\n                gate_weight_list.append(gate_weight_tp)\n                up_weight_list.append(up_weight_tp)\n\n            state_dict[gate_name] = torch.cat(gate_weight_list, dim=0)\n            state_dict[up_name] = torch.cat(up_weight_list, dim=0)\n\n    def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank):\n        
\"\"\"broadcast tensor in tp shards across mp_group\"\"\"\n        nonlocal state_dict\n        nonlocal mp_group\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)\n\n        chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None\n\n        obj_list = [chunk_shape]\n        dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)\n        chunk_shape = obj_list[0]\n        if chunk_shape is None:\n            # all or none ranks in the mp_group should reach here\n            print_rank_0(f\"tp_shard tensor:[{q_name}] not exist, skip collecting\")\n            return\n\n        buffer_tensor = torch.empty(\n            chunk_shape,\n            dtype=dtype,\n            device=get_device_id(),\n            requires_grad=False,\n        )\n\n        chunk_tensors = [None] * tp_size\n\n        for i in range(tp_size):\n            cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)\n            sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor\n            dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)\n\n            if torch.distributed.get_rank() == 0:\n                chunk_tensors[i] = _get_cpu_tensor(sync_tensor)\n\n        if torch.distributed.get_rank() == 0:\n            full_tensor = torch.concat(chunk_tensors, dim=0)\n            q_weight_list = []\n            k_weight_list = []\n            v_weight_list = []\n            hidden_size_per_head = config.hidden_size // config.num_attention_heads\n\n            if config.num_key_value_heads >= tp_size:\n                q_size_tp = config.hidden_size // tp_size\n                kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size\n                total_size = q_size_tp + 2 * kv_size_tp\n                for i in range(tp_size):\n                    qkv_part = full_tensor[i * total_size : (i + 1) * total_size]\n                    q_part = qkv_part[:q_size_tp]\n                    k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp]\n                    v_part = qkv_part[q_size_tp + kv_size_tp : total_size]\n                    q_weight_list.append(q_part)\n                    k_weight_list.append(k_part)\n                    v_weight_list.append(v_part)\n            else:\n                q_size_tp = config.hidden_size // tp_size\n                kv_size_tp = hidden_size_per_head\n                total_size = q_size_tp + 2 * kv_size_tp\n                for i in range(tp_size):\n                    qkv_part = full_tensor[i * total_size : (i + 1) * total_size]\n                    q_part = qkv_part[:q_size_tp]\n                    k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp]\n                    v_part = qkv_part[q_size_tp + kv_size_tp : total_size]\n                    q_weight_list.append(q_part)\n                    if i * config.num_key_value_heads % tp_size == 0:\n                        k_weight_list.append(k_part)\n                        v_weight_list.append(v_part)\n\n            state_dict[q_name] = torch.cat(q_weight_list, dim=0)\n            state_dict[k_name] = torch.cat(k_weight_list, dim=0)\n            state_dict[v_name] = torch.cat(v_weight_list, dim=0)\n\n    # empty cache before collecting weights\n    get_torch_device().empty_cache()\n    # Embeddings\n    # -------------------\n    if dp_rank == 0:\n        # Embeddings\n        # -------------------\n        
print_rank_0(\"collecting embeddings...\")\n        gpt_model_module = _get_gpt_model(models[0])\n        _broadcast_tp_shard_tensor(\n            gpt_model_module.model.embed_tokens.weight if pp_rank == 0 else None,\n            \"model.embed_tokens.weight\",\n            src_pp_rank=0,\n        )\n\n        # Transformer layers\n        # -------------------\n        layer_map = _megatron_calc_layer_map(config)\n        for layer in range(config.num_hidden_layers):\n            print_rank_0(f\"collecting layer #{layer}...\")\n            layer_name = f\"model.layers.{layer}\"\n            src_pp_rank, src_virtual_pp_rank, src_layer_idx = layer_map[layer]\n\n            gpt_model_module = _get_gpt_model(models[src_virtual_pp_rank])\n            sync_layer = gpt_model_module.model.layers[src_layer_idx]\n\n            _broadcast_tensor(\n                sync_layer.input_layernorm.weight,\n                f\"{layer_name}.input_layernorm.weight\",\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tp_shard_tensor_qkv(\n                sync_layer.self_attn.qkv_proj.weight,\n                f\"{layer_name}.self_attn.q_proj.weight\",\n                f\"{layer_name}.self_attn.k_proj.weight\",\n                f\"{layer_name}.self_attn.v_proj.weight\",\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tp_shard_tensor_qkv(\n                sync_layer.self_attn.qkv_proj.bias,\n                f\"{layer_name}.self_attn.q_proj.bias\",\n                f\"{layer_name}.self_attn.k_proj.bias\",\n                f\"{layer_name}.self_attn.v_proj.bias\",\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tp_shard_tensor(\n                sync_layer.self_attn.o_proj.weight,\n                f\"{layer_name}.self_attn.o_proj.weight\",\n                concat_dim=1,\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tensor(\n                sync_layer.post_attention_layernorm.weight,\n                f\"{layer_name}.post_attention_layernorm.weight\",\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tp_shard_tensor_gate_up(\n                sync_layer.mlp.gate_up_proj.weight,\n                f\"{layer_name}.mlp.gate_proj.weight\",\n                f\"{layer_name}.mlp.up_proj.weight\",\n                src_pp_rank=src_pp_rank,\n            )\n\n            _broadcast_tp_shard_tensor(\n                sync_layer.mlp.down_proj.weight,\n                f\"{layer_name}.mlp.down_proj.weight\",\n                concat_dim=1,\n                src_pp_rank=src_pp_rank,\n            )\n\n        # Final Layernorm\n        # -------------------\n        print_rank_0(\"collecting final layernorm...\")\n        gpt_model_module = _get_gpt_model(models[-1])\n        _broadcast_tensor(\n            getattr(gpt_model_module.model.norm, \"weight\", None),\n            \"model.norm.weight\",\n            src_pp_rank=pp_size - 1,\n        )\n\n        if tie_word_embeddings:\n            print_rank_0(\"tie word embedding skip load lm_head...\")\n        else:\n            print_rank_0(\"collecting lm_head...\")\n\n            if is_value_model:\n                _broadcast_tensor(\n                    gpt_model_module.lm_head.weight if pp_rank == pp_size - 1 else None,\n                    \"lm_head.weight\",\n                    src_pp_rank=pp_size - 1,\n                )\n                _broadcast_tensor(\n                    gpt_model_module.reward_head.weight if 
pp_rank == pp_size - 1 and getattr(gpt_model_module, \"reward_weight\", None) is not None else None,\n                    \"reward_head.weight\",\n                    src_pp_rank=pp_size - 1,\n                )\n\n            else:\n                _broadcast_tp_shard_tensor(\n                    getattr(gpt_model_module.lm_head, \"weight\", None) if pp_rank == pp_size - 1 else None,\n                    \"lm_head.weight\",\n                    src_pp_rank=pp_size - 1,\n                )\n\n    dist.barrier()\n\n    get_torch_device().empty_cache()\n    if torch.distributed.get_rank() == 0:\n        for k, v in state_dict.items():\n            if dtype != v.dtype:\n                state_dict[k] = v.to(dtype)\n\n    print_rank_0(f\"merge megatron ckpt done, time elapsed {time.time() - start_time}s\")\n    return state_dict\n"
  },
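  {
    "path": "examples/qkv_tp_resplit_sketch.py",
    "content": "# Hypothetical sketch (illustrative only; the path and the toy sizes below are assumptions, not part of the\n# upstream package): shows on CPU, with plain torch, how the checkpoint merger's _broadcast_tp_shard_tensor_qkv\n# re-splits a fused [q | k | v] weight that was sharded along its output dimension across tensor-parallel ranks\n# (the config.num_key_value_heads >= tp_size branch).\nimport torch\n\nhidden_size = 8\nnum_attention_heads = 4\nnum_key_value_heads = 2\ntp_size = 2\nhead_dim = hidden_size // num_attention_heads\n\nq_size_tp = hidden_size // tp_size\nkv_size_tp = head_dim * num_key_value_heads // tp_size\ntotal_size = q_size_tp + 2 * kv_size_tp\n\n# Concatenating every rank's fused shard along dim 0 gives the tensor the merger receives.\nfull_tensor = torch.arange(tp_size * total_size * hidden_size, dtype=torch.float32).reshape(tp_size * total_size, hidden_size)\n\nq_parts, k_parts, v_parts = [], [], []\nfor i in range(tp_size):\n    qkv_part = full_tensor[i * total_size : (i + 1) * total_size]\n    q_parts.append(qkv_part[:q_size_tp])\n    k_parts.append(qkv_part[q_size_tp : q_size_tp + kv_size_tp])\n    v_parts.append(qkv_part[q_size_tp + kv_size_tp : total_size])\n\n# Same concatenation the merger stores under the q_proj / k_proj / v_proj names.\nq_weight = torch.cat(q_parts, dim=0)  # (hidden_size, hidden_size)\nk_weight = torch.cat(k_parts, dim=0)  # (num_key_value_heads * head_dim, hidden_size)\nv_weight = torch.cat(v_parts, dim=0)\nprint(q_weight.shape, k_weight.shape, v_weight.shape)\n"
  },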
  {
    "path": "siirl/models/qwen2/megatron/layers/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .parallel_attention import ParallelQwen2Attention\nfrom .parallel_decoder import ParallelQwen2DecoderLayer, ParallelQwen2DecoderLayerRmPad\nfrom .parallel_mlp import ParallelQwen2MLP\nfrom .parallel_rmsnorm import ParallelQwen2RMSNorm\n\n__all__ = [\"ParallelQwen2Attention\", \"ParallelQwen2DecoderLayer\", \"ParallelQwen2DecoderLayerRmPad\", \"ParallelQwen2MLP\", \"ParallelQwen2RMSNorm\"]\n"
  },
  {
    "path": "siirl/models/qwen2/megatron/layers/parallel_attention.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport math\nfrom typing import Optional, Tuple\n\nimport torch.nn.functional as F\nfrom einops import rearrange\nfrom transformers.utils import is_flash_attn_2_available\n\nif is_flash_attn_2_available():\n    from flash_attn import flash_attn_varlen_func\n    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa\nimport torch\nfrom flash_attn.layers.rotary import apply_rotary_emb\nfrom megatron.core import ModelParallelConfig, tensor_parallel\nfrom megatron.core import parallel_state as mpu\nfrom torch import nn\nfrom transformers import Qwen2Config\n\nfrom siirl.models.qwen2.megatron.layers.parallel_linear import QKVParallelLinear\nfrom siirl.utils.megatron import tensor_parallel as tp_utils\n\n\nclass Qwen2RotaryEmbedding(nn.Module):\n    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):\n        super().__init__()\n\n        self.dim = dim\n        self.max_position_embeddings = max_position_embeddings\n        self.base = base\n        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n\n        # Build here to make `torch.jit.trace` work.\n        self._set_cos_sin_cache(seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype())\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)\n\n        freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\"cos_cached\", emb.cos().to(dtype), persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin().to(dtype), persistent=False)\n\n    def forward(self, x, seq_len=None):\n        # x: [bs, num_attention_heads, seq_len, head_size]\n        if seq_len > self.max_seq_len_cached:\n            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)\n\n        return (\n            self.cos_cached[:seq_len].to(dtype=x.dtype),\n            self.sin_cached[:seq_len].to(dtype=x.dtype),\n        )\n\n\nclass Qwen2LinearScalingRotaryEmbedding(Qwen2RotaryEmbedding):\n    \"\"\"Qwen2RotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev\"\"\"\n\n    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):\n        self.scaling_factor = scaling_factor\n        super().__init__(dim, max_position_embeddings, base, device)\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)\n        t = t / self.scaling_factor\n\n        freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\"cos_cached\", emb.cos().to(dtype), persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin().to(dtype), persistent=False)\n\n\nclass Qwen2DynamicNTKScalingRotaryEmbedding(Qwen2RotaryEmbedding):\n    \"\"\"Qwen2RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla\"\"\"\n\n    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):\n        self.scaling_factor = scaling_factor\n        super().__init__(dim, max_position_embeddings, base, device)\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n\n        if seq_len > self.max_position_embeddings:\n            base = self.base * ((self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)) ** (self.dim / (self.dim - 2))\n            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))\n            self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n\n        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)\n\n        freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\"cos_cached\", emb.cos().to(dtype), persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin().to(dtype), persistent=False)\n\n\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids):\n    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]\n    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed, k_embed\n\n\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\nclass ParallelQwen2Attention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig):\n        super().__init__()\n        self.config = config\n        self.megatron_config = megatron_config\n        self.hidden_size = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_dim = self.hidden_size // self.num_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.num_key_value_groups = self.num_heads // self.num_key_value_heads\n        self.max_position_embeddings = config.max_position_embeddings\n        self.rope_theta = config.rope_theta\n\n        # assign values after tp\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n        assert self.num_heads % tp_size == 0, f\"num_head must be divisible by tp_size. Got num_head={self.num_heads}, tp_size={tp_size}\"\n        assert self.num_key_value_heads % tp_size == 0, f\"num_key_value_heads must be divisible by tp_size. Got num_key_value_heads={self.num_key_value_heads}, tp_size={tp_size}\"\n\n        self.num_heads_per_tp = self.num_heads // tp_size\n        self.num_key_value_heads_per_tp = self.num_key_value_heads // tp_size\n        self.hidden_size_per_tp = self.hidden_size // tp_size\n\n        if (self.head_dim * self.num_heads) != self.hidden_size:\n            raise ValueError(f\"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`: {self.num_heads}).\")\n\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear()\n\n        if megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            assert row_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(column_kwargs, megatron_config)\n            tp_utils.update_kwargs_with_config(row_kwargs, megatron_config)\n\n        # [self.q_size, self.k_size, self.v_size]\n        self.qkv_proj = QKVParallelLinear(\n            input_size=self.hidden_size,\n            num_heads=self.num_heads,\n            num_key_value_heads=self.num_key_value_heads,\n            head_dim=self.head_dim,\n            # bias=config.attention_bias,\n            bias=True,\n            gather_output=False,\n            skip_bias_add=False,\n            **column_kwargs,\n        )\n\n        self.q_size = self.num_heads_per_tp * self.head_dim\n        self.k_size = self.num_key_value_heads_per_tp * self.head_dim\n        self.v_size = self.num_key_value_heads_per_tp * self.head_dim\n\n        self.o_proj = tensor_parallel.RowParallelLinear(\n            input_size=self.num_heads * self.head_dim,\n            output_size=self.hidden_size,\n            # bias=config.attention_bias,\n            bias=False,\n            input_is_parallel=True,\n            skip_bias_add=False,\n            
**row_kwargs,\n        )\n\n        self._init_rope()\n\n    def _init_rope(self):\n        self.rotary_emb = Qwen2RotaryEmbedding(\n            self.head_dim,\n            max_position_embeddings=self.max_position_embeddings,\n            base=self.rope_theta,\n        )\n\n    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):\n        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        bsz, q_len, _ = hidden_states.size()\n        qkv = self.qkv_proj(hidden_states)[0]\n        query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)\n\n        query_states = query_states.view(bsz, q_len, self.num_heads_per_tp, self.head_dim).transpose(1, 2)\n        key_states = key_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2)\n        value_states = value_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2)\n\n        kv_seq_len = key_states.shape[-2]\n        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)\n        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)\n\n        key_states = repeat_kv(key_states, self.num_key_value_groups)\n        value_states = repeat_kv(value_states, self.num_key_value_groups)\n\n        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)\n\n        if attn_weights.size() != (bsz, self.num_heads_per_tp, q_len, kv_seq_len):\n            raise ValueError(f\"Attention weights should be of size {(bsz, self.num_heads_per_tp, q_len, kv_seq_len)}, but is {attn_weights.size()}\")\n\n        if attention_mask is not None:\n            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):\n                raise ValueError(f\"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}\")\n            attn_weights = attn_weights + attention_mask\n\n        # upcast attention to fp32\n        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)\n        attn_output = torch.matmul(attn_weights, value_states)\n\n        if attn_output.size() != (bsz, self.num_heads_per_tp, q_len, self.head_dim):\n            raise ValueError(f\"`attn_output` should be of size {(bsz, self.num_heads_per_tp, q_len, self.head_dim)}, but is {attn_output.size()}\")\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size_per_tp)\n        attn_output = self.o_proj(attn_output)[0]\n        return attn_output\n\n\n\"\"\"\nRemove padding Attention\n- Using Flash-attn 2\n- Compatible with sequence parallel\n\"\"\"\n\n\ndef apply_rotary_pos_emb_rmpad(q, k, cos, sin, position_ids, indices, sequence_length):\n    batch_size = position_ids.shape[0]\n\n    q = pad_input(q, indices, batch_size, sequence_length)  # (batch_size, seqlen, num_head, head_dim)\n    k = pad_input(k, indices, batch_size, sequence_length)\n    cos = cos[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]\n    sin = sin[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    
k_embed = (k * cos) + (rotate_half(k) * sin)\n\n    q_embed = index_first_axis(rearrange(q_embed, \"b s ... -> (b s) ...\"), indices)\n    k_embed = index_first_axis(rearrange(k_embed, \"b s ... -> (b s) ...\"), indices)\n\n    return q_embed, k_embed\n\n\n# use flash-attn rotary embeddings with rmpad\n# cos/sin should be: (seq_length, rotary_dim / 2)\ndef apply_rotary_pos_emb_rmpad_flash(q, k, cos, sin, cu_seqlens, max_seqlen):\n    q_embed = apply_rotary_emb(q, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)\n    k_embed = apply_rotary_emb(k, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)\n    return q_embed, k_embed\n\n\nclass ParallelQwen2AttentionRmPad(ParallelQwen2Attention):\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        position_ids: Optional[torch.LongTensor] = None,\n        sequence_length: int = None,\n        indices: torch.Tensor = None,\n        cu_seqlens: torch.Tensor = None,\n        max_seqlen_in_batch: int = None,\n    ):\n        total_nnz, _, _ = hidden_states.size()  # This is the total_nnz padded after sequence parallel\n\n        if self.megatron_config.sequence_parallel:\n            total_nnz = total_nnz * mpu.get_tensor_model_parallel_world_size()\n\n        qkv = self.qkv_proj(hidden_states)[0]\n        query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)  # (total_nnz, 1, hidden_size)\n\n        if self.megatron_config.sequence_parallel:\n            sequence_parallel_pad = total_nnz - cu_seqlens[-1]\n            total_nnz = cu_seqlens[-1]  # total_nnz before sp padding\n            query_states = query_states[:total_nnz]\n            key_states = key_states[:total_nnz]\n            value_states = value_states[:total_nnz]\n\n        # Flash attention requires the input to have the shape\n        # batch_size x seq_length x head_dim x hidden_dim\n        # therefore we just need to keep the original shape\n        query_states = query_states.view(total_nnz, self.num_heads_per_tp, self.head_dim)\n        key_states = key_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim)\n        value_states = value_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim)\n\n        cos, sin = self.rotary_emb(value_states, seq_len=sequence_length)\n        cos, sin = cos[:, : cos.shape[1] // 2], sin[:, : sin.shape[1] // 2]  # flash attn only needs half\n        query_states, key_states = apply_rotary_pos_emb_rmpad_flash(query_states, key_states, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen_in_batch)\n        # query_states, key_states = apply_rotary_pos_emb_rmpad(query_states, key_states, cos, sin, position_ids, indices,\n\n        # It is recommended to use dropout with FA according to the docs\n        # when training.\n        dropout_rate = 0.0  # if not self.training else self.attn_dropout\n\n        # In PEFT, usually we cast the layer norms in float32 for training stability reasons\n        # therefore the input hidden states get silently cast to float32. Hence, we need to\n        # cast them back to float16 just to be sure everything works as expected.\n        # This might slow down training & inference so it is recommended to not cast the LayerNorms\n        # in fp32. 
(Qwen2RMSNorm handles it correctly)\n        input_dtype = query_states.dtype\n        if input_dtype == torch.float32:\n            query_states = query_states.to(torch.float16)\n            key_states = key_states.to(torch.float16)\n            value_states = value_states.to(torch.float16)\n\n        attn_output_unpad = flash_attn_varlen_func(\n            query_states,\n            key_states,\n            value_states,\n            cu_seqlens_q=cu_seqlens,\n            cu_seqlens_k=cu_seqlens,\n            max_seqlen_q=max_seqlen_in_batch,\n            max_seqlen_k=max_seqlen_in_batch,\n            dropout_p=dropout_rate,\n            softmax_scale=None,\n            causal=True,\n        )\n\n        attn_output_unpad = attn_output_unpad.to(input_dtype)\n        attn_output_unpad = attn_output_unpad.reshape(total_nnz, 1, self.hidden_size_per_tp).contiguous()\n\n        # sequence parallel reduce_scatter is performed inside RowColumnParallel if enabled\n        # Here we need to repad\n        if self.megatron_config.sequence_parallel:\n            attn_output_unpad = F.pad(attn_output_unpad, pad=(0, 0, 0, 0, 0, sequence_parallel_pad))\n\n        attn_output_unpad = self.o_proj(attn_output_unpad)[0]\n        return attn_output_unpad\n"
  },
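  {
    "path": "examples/rotary_and_gqa_sketch.py",
    "content": "# Hypothetical sketch (illustrative only; the path and the toy sizes are assumptions, not part of the upstream\n# package): reproduces two helpers from parallel_attention.py with plain torch -- repeat_kv (GQA key/value head\n# replication) and the rotate_half-based rotary position embedding -- so their shape behaviour is easy to check.\nimport torch\n\n\ndef rotate_half(x):\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\ndef repeat_kv(hidden_states, n_rep):\n    batch, num_kv_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)\n    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)\n\n\nbsz, num_heads, num_kv_heads, seqlen, head_dim = 1, 4, 2, 3, 8\nk = torch.randn(bsz, num_kv_heads, seqlen, head_dim)\n# repeat_kv matches torch.repeat_interleave along the head dimension, as its docstring states.\nassert torch.equal(repeat_kv(k, num_heads // num_kv_heads), torch.repeat_interleave(k, repeats=num_heads // num_kv_heads, dim=1))\n\n# cos/sin cache built from inv_freq, mirroring Qwen2RotaryEmbedding._set_cos_sin_cache.\ninv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))\nt = torch.arange(seqlen).float()\nemb = torch.cat((torch.outer(t, inv_freq), torch.outer(t, inv_freq)), dim=-1)\ncos, sin = emb.cos(), emb.sin()\n\nq = torch.randn(bsz, num_heads, seqlen, head_dim)\nposition_ids = torch.arange(seqlen).unsqueeze(0)\ncos_sel = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]\nsin_sel = sin[position_ids].unsqueeze(1)\nq_embed = (q * cos_sel) + (rotate_half(q) * sin_sel)  # same formula as apply_rotary_pos_emb\nprint(q_embed.shape)  # torch.Size([1, 4, 3, 8])\n"
  },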
  {
    "path": "siirl/models/qwen2/megatron/layers/parallel_decoder.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Optional, Tuple\n\nimport torch\nfrom megatron.core import ModelParallelConfig\nfrom torch import nn\nfrom transformers import Qwen2Config\n\nfrom siirl.utils.megatron.megatron_utils import TransformerConfig, convert_config\n\nfrom .parallel_attention import ParallelQwen2Attention, ParallelQwen2AttentionRmPad\nfrom .parallel_mlp import ParallelQwen2MLP\nfrom .parallel_rmsnorm import ParallelQwen2RMSNorm\n\n\nclass ParallelQwen2DecoderLayer(nn.Module):\n    def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig, layer_idx: int):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.layer_idx = layer_idx\n        self.hidden_size = config.hidden_size\n        self.self_attn = ParallelQwen2Attention(config=config, megatron_config=megatron_config)\n\n        self.mlp = ParallelQwen2MLP(config, megatron_config=megatron_config)\n        self.input_layernorm = ParallelQwen2RMSNorm(config, megatron_config)\n        self.post_attention_layernorm = ParallelQwen2RMSNorm(config, megatron_config)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size\n                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under\n                returned tensors for more detail.\n            use_cache (`bool`, *optional*):\n                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n                (see `past_key_values`).\n            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n        \"\"\"\n\n        residual = hidden_states\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Note: sequence parallel is hidden inside ColumnParallelLinear\n        # reduce scatter is hidden inside RowParallelLinear\n\n        # Self Attention\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n        )\n\n        # TODO: add sequence parallel operator reduce_scatter here\n\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n\n        # TODO: add sequence parallel operator all_gather here\n\n        hidden_states = self.mlp(hidden_states)\n\n        # TODO: add sequence parallel operator reduce_scatter here\n\n        hidden_states = residual + hidden_states\n\n        outputs = hidden_states\n\n        return outputs\n\n\nclass ParallelQwen2DecoderLayerRmPad(nn.Module):\n    def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig, layer_idx: int):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.hidden_size = config.hidden_size\n        self.layer_idx = layer_idx\n        self.self_attn = ParallelQwen2AttentionRmPad(config=config, megatron_config=megatron_config)\n\n        self.mlp = ParallelQwen2MLP(config, megatron_config=megatron_config)\n        self.input_layernorm = ParallelQwen2RMSNorm(config, megatron_config)\n        self.post_attention_layernorm = ParallelQwen2RMSNorm(config, megatron_config)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        position_ids: Optional[torch.LongTensor] = None,\n        sequence_length: int = None,\n        indices: torch.Tensor = None,\n        cu_seqlens: int = None,\n        max_seqlen_in_batch: int = None,\n    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:\n        residual = hidden_states  # (total_nnz // sp, 1, hidden_size)\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        # (total_nnz // sp, 1, hidden_size) -> all-gather (total_nnz, 1, hidden_size)\n        # -> col + row -> reduce-scatter -> (total_nnz // sp, 1, hidden_size)\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            position_ids=position_ids,\n            sequence_length=sequence_length,\n            indices=indices,\n            cu_seqlens=cu_seqlens,\n            max_seqlen_in_batch=max_seqlen_in_batch,\n        )\n\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        # shape changes same as attn\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n\n        outputs = hidden_states\n\n        return outputs\n"
  },
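  {
    "path": "examples/decoder_residual_sketch.py",
    "content": "# Hypothetical sketch (illustrative only; the path, toy sizes, and the attention/MLP stand-ins are assumptions,\n# not part of the upstream package): the pre-norm residual wiring of ParallelQwen2DecoderLayer.forward written\n# with plain torch modules and no tensor/sequence parallelism, to make the norm -> attn -> add, norm -> mlp -> add\n# flow explicit.\nimport torch\nfrom torch import nn\n\n\nclass ToyRMSNorm(nn.Module):\n    def __init__(self, hidden_size, eps=1e-6):\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n        self.eps = eps\n\n    def forward(self, x):\n        variance = x.pow(2).mean(-1, keepdim=True)\n        return self.weight * x * torch.rsqrt(variance + self.eps)\n\n\nclass ToyDecoderLayer(nn.Module):\n    def __init__(self, hidden_size=16, intermediate_size=32):\n        super().__init__()\n        self.input_layernorm = ToyRMSNorm(hidden_size)\n        self.post_attention_layernorm = ToyRMSNorm(hidden_size)\n        self.self_attn = nn.MultiheadAttention(hidden_size, num_heads=4, batch_first=True)  # stand-in for ParallelQwen2Attention\n        self.mlp = nn.Sequential(nn.Linear(hidden_size, intermediate_size), nn.SiLU(), nn.Linear(intermediate_size, hidden_size))  # stand-in for ParallelQwen2MLP\n\n    def forward(self, hidden_states):\n        residual = hidden_states\n        hidden_states = self.input_layernorm(hidden_states)\n        hidden_states, _ = self.self_attn(hidden_states, hidden_states, hidden_states, need_weights=False)\n        hidden_states = residual + hidden_states\n\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        return residual + hidden_states\n\n\nprint(ToyDecoderLayer()(torch.randn(2, 5, 16)).shape)  # torch.Size([2, 5, 16])\n"
  },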
  {
    "path": "siirl/models/qwen2/megatron/layers/parallel_linear.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The vLLM team.\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/linear.py\n\n\nfrom megatron.core import tensor_parallel\n\n\nclass QKVParallelLinear(tensor_parallel.ColumnParallelLinear):\n    def __init__(\n        self,\n        input_size,\n        num_heads,\n        num_key_value_heads,\n        head_dim,\n        *,\n        bias=True,\n        gather_output=True,\n        skip_bias_add=False,\n        **kwargs,\n    ):\n        # Keep input parameters, and already restrict the head numbers\n        self.input_size = input_size\n        self.q_output_size = num_heads * head_dim\n        self.kv_output_size = num_key_value_heads * head_dim\n        self.head_dim = head_dim\n        self.gather_output = gather_output\n        self.skip_bias_add = skip_bias_add\n\n        input_size = self.input_size\n        output_size = (num_heads + 2 * num_key_value_heads) * self.head_dim\n\n        super().__init__(\n            input_size=input_size,\n            output_size=output_size,\n            bias=bias,\n            gather_output=gather_output,\n            skip_bias_add=skip_bias_add,\n            **kwargs,\n        )\n\n\nclass MergedColumnParallelLinear(tensor_parallel.ColumnParallelLinear):\n    def __init__(\n        self,\n        input_size,\n        gate_ouput_size,\n        up_output_size,\n        *,\n        bias=True,\n        gather_output=True,\n        skip_bias_add=False,\n        **kwargs,\n    ):\n        # Keep input parameters, and already restrict the head numbers\n        self.input_size = input_size\n        self.output_size = gate_ouput_size + up_output_size\n        self.gather_output = gather_output\n        self.skip_bias_add = skip_bias_add\n\n        super().__init__(\n            input_size=self.input_size,\n            output_size=self.output_size,\n            bias=bias,\n            gather_output=gather_output,\n            skip_bias_add=skip_bias_add,\n            **kwargs,\n        )\n"
  },
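  {
    "path": "examples/qkv_parallel_linear_sizes.py",
    "content": "# Hypothetical sketch (illustrative only; the path and the Qwen2-7B-like numbers below are assumptions, not read\n# from any config): the fused output width QKVParallelLinear hands to ColumnParallelLinear, and the per\n# tensor-parallel-rank q/k/v slice sizes that ParallelQwen2Attention later uses for qkv.split(...).\nhidden_size = 3584\nnum_attention_heads = 28\nnum_key_value_heads = 4\ntp_size = 2\nhead_dim = hidden_size // num_attention_heads  # 128\n\noutput_size = (num_attention_heads + 2 * num_key_value_heads) * head_dim  # fused [q | k | v] width = 4608\nq_size_per_tp = (num_attention_heads // tp_size) * head_dim  # 1792\nkv_size_per_tp = (num_key_value_heads // tp_size) * head_dim  # 256\n\n# Each rank holds one column-parallel shard of the fused projection.\nassert output_size // tp_size == q_size_per_tp + 2 * kv_size_per_tp\nprint(output_size, q_size_per_tp, kv_size_per_tp)\n"
  },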
  {
    "path": "siirl/models/qwen2/megatron/layers/parallel_mlp.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom megatron.core import ModelParallelConfig, tensor_parallel\nfrom megatron.core import parallel_state as mpu\nfrom torch import nn\nfrom transformers.activations import ACT2FN\n\nfrom siirl.models.qwen2.megatron.layers.parallel_linear import MergedColumnParallelLinear\nfrom siirl.utils.megatron import tensor_parallel as tp_utils\n\n\nclass ParallelQwen2MLP(nn.Module):\n    def __init__(self, config, megatron_config: ModelParallelConfig = None) -> None:\n        super().__init__()\n        self.config = config\n        self.hidden_size = config.hidden_size\n        self.intermediate_size = config.intermediate_size\n        # The weight is only [hidden_size, intermediate_size // model_parallel_world_size]\n\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear()\n\n        if megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            assert row_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(row_kwargs, megatron_config)\n            tp_utils.update_kwargs_with_config(column_kwargs, megatron_config)\n\n        tp_size = mpu.get_tensor_model_parallel_world_size()\n\n        self.gate_up_proj = MergedColumnParallelLinear(\n            input_size=self.hidden_size,\n            gate_ouput_size=self.intermediate_size,\n            up_output_size=self.intermediate_size,\n            bias=False,\n            gather_output=False,\n            skip_bias_add=False,\n            **column_kwargs,\n        )\n        self.gate_size = self.intermediate_size // tp_size\n\n        self.down_proj = tensor_parallel.RowParallelLinear(\n            input_size=self.intermediate_size,\n            output_size=self.hidden_size,\n            bias=False,\n            input_is_parallel=True,\n            skip_bias_add=False,\n            **row_kwargs,\n        )\n\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, x):\n        gate_up = self.gate_up_proj(x)[0]\n        gate, up = gate_up.split(self.gate_size, dim=-1)\n        return self.down_proj(self.act_fn(gate) * up)[0]\n"
  },
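  {
    "path": "examples/gate_up_mlp_sketch.py",
    "content": "# Hypothetical sketch (illustrative only; the path, toy sizes, and the plain nn.Linear stand-ins are assumptions,\n# not part of the upstream package): the gate/up split and SiLU-gated product computed by ParallelQwen2MLP.forward,\n# with MergedColumnParallelLinear / RowParallelLinear replaced by ordinary linear layers (no tensor parallelism).\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nhidden_size, intermediate_size = 16, 40\ngate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)  # stand-in for MergedColumnParallelLinear\ndown_proj = nn.Linear(intermediate_size, hidden_size, bias=False)  # stand-in for RowParallelLinear\n\nx = torch.randn(2, 5, hidden_size)\ngate, up = gate_up_proj(x).split(intermediate_size, dim=-1)  # same split as ParallelQwen2MLP.forward\ny = down_proj(F.silu(gate) * up)  # Qwen2 configs use hidden_act == \"silu\"\nprint(y.shape)  # torch.Size([2, 5, 16])\n"
  },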
  {
    "path": "siirl/models/qwen2/megatron/layers/parallel_rmsnorm.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport numbers\n\nimport torch\nfrom apex.normalization.fused_layer_norm import fused_rms_norm_affine\nfrom megatron.core import ModelParallelConfig\nfrom torch import nn\nfrom transformers import Qwen2Config\n\nfrom siirl.utils.megatron import sequence_parallel as sp_utils\n\n\nclass ParallelQwen2RMSNorm(nn.Module):\n    def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig):\n        \"\"\"\n        Qwen2RMSNorm is equivalent to T5LayerNorm\n        \"\"\"\n        super().__init__()\n        if isinstance(config.hidden_size, numbers.Integral):\n            normalized_shape = (config.hidden_size,)\n        self.normalized_shape = torch.Size(normalized_shape)\n        self.weight = nn.Parameter(torch.ones(self.normalized_shape))\n        self.variance_epsilon = config.rms_norm_eps\n\n        if megatron_config.sequence_parallel:\n            sp_utils.mark_parameter_as_sequence_parallel(self.weight)\n\n    def forward(self, hidden_states):\n        return fused_rms_norm_affine(\n            input=hidden_states,\n            weight=self.weight,\n            normalized_shape=self.normalized_shape,\n            eps=self.variance_epsilon,\n            memory_efficient=True,\n        )\n"
  },
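  {
    "path": "examples/rmsnorm_reference_sketch.py",
    "content": "# Hypothetical sketch (illustrative only; the path and toy sizes are assumptions, not part of the upstream\n# package): the normalization that ParallelQwen2RMSNorm delegates to apex's fused_rms_norm_affine, spelled out in\n# plain PyTorch so the formula -- x * rsqrt(mean(x**2) + eps) scaled by a learned weight -- is visible without apex.\nimport torch\n\n\ndef rms_norm_reference(hidden_states, weight, eps=1e-6):\n    variance = hidden_states.float().pow(2).mean(-1, keepdim=True)\n    normed = hidden_states.float() * torch.rsqrt(variance + eps)\n    return (weight * normed).to(hidden_states.dtype)\n\n\nhidden_size = 16\nweight = torch.ones(hidden_size)\nx = torch.randn(2, 5, hidden_size, dtype=torch.bfloat16)\nprint(rms_norm_reference(x, weight, eps=1e-6).shape)  # torch.Size([2, 5, 16])\n"
  },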
  {
    "path": "siirl/models/qwen2/megatron/modeling_qwen2_megatron.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"PyTorch Qwen2 model.\"\"\"\n\nfrom typing import Optional, Tuple, Union\n\nimport torch\nimport torch.utils.checkpoint\nfrom megatron.core import ModelParallelConfig, mpu, parallel_state, tensor_parallel\nfrom torch import nn\nfrom transformers.modeling_outputs import BaseModelOutputWithPast\nfrom transformers.models.qwen2.configuration_qwen2 import Qwen2Config\nfrom transformers.models.qwen2.modeling_qwen2 import CausalLMOutputWithPast\n\nfrom siirl.utils.extras.device import get_device_id, get_torch_device\nfrom siirl.utils.megatron import sequence_parallel as sp_utils\nfrom siirl.utils.megatron import tensor_parallel as tp_utils\nfrom siirl.utils.megatron.megatron_utils import TransformerConfig, convert_config\n\nfrom .layers import ParallelQwen2DecoderLayer, ParallelQwen2DecoderLayerRmPad, ParallelQwen2RMSNorm\n\n\"\"\"\nTODO: \n1. Add weight initialization. Here we need to be careful on TP weight init.\n2. Add sequence parallel\n3. Load checkpoint from Qwen2 pretrained checkpoint\n\"\"\"\n\n\n# Copied from transformers.models.bart.modeling_bart._make_causal_mask\ndef _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device):\n    \"\"\"\n    Make causal mask used for bi-directional self-attention.\n    \"\"\"\n    bsz, tgt_len = input_ids_shape\n    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)\n    mask_cond = torch.arange(mask.size(-1), device=device)\n    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)\n    mask = mask.to(dtype)\n    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len)\n\n\n# Copied from transformers.models.bart.modeling_bart._expand_mask\ndef _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):\n    \"\"\"\n    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.\n    \"\"\"\n    bsz, src_len = mask.size()\n    tgt_len = tgt_len if tgt_len is not None else src_len\n\n    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)\n\n    inverted_mask = 1.0 - expanded_mask\n\n    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)\n\n\nclass ParallelQwen2Model(nn.Module):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Qwen2DecoderLayer`]\n\n    Args:\n        config: Qwen2Config\n    \"\"\"\n\n    def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n        embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()\n        if megatron_config is not None:\n            assert embedding_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(embedding_kwargs, megatron_config)\n        self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs)\n\n        self.layers = nn.ModuleList([ParallelQwen2DecoderLayer(config, megatron_config) for _ in range(config.num_hidden_layers)])\n        self.norm = ParallelQwen2RMSNorm(config, megatron_config)\n\n    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask\n    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds):\n        # create causal mask\n        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n        combined_attention_mask = None\n        if input_shape[-1] > 1:\n            combined_attention_mask = _make_causal_mask(\n                input_shape,\n                inputs_embeds.dtype,\n                device=inputs_embeds.device,\n            )\n\n        if attention_mask is not None:\n            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device)\n            combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask\n\n        return combined_attention_mask\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> Union[Tuple, BaseModelOutputWithPast]:\n        \"\"\"\n\n        Args:\n            input_ids: input ids. shape (batch_size, seq_length)\n            attention_mask: attention_mask. shape (batch_size, seq_length)\n            position_ids: position ids. 
shape (batch_size, seq_length)\n\n        Returns:\n\n        \"\"\"\n        batch_size, seq_length = input_ids.shape\n        inputs_embeds = self.embed_tokens(input_ids)\n        # embed positions\n\n        attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds)\n\n        hidden_states = inputs_embeds\n\n        for idx, decoder_layer in enumerate(self.layers):\n            layer_outputs = decoder_layer(\n                hidden_states,\n                attention_mask=attention_mask,\n                position_ids=position_ids,\n            )\n\n            hidden_states = layer_outputs\n\n        hidden_states = self.norm(hidden_states)\n\n        return hidden_states\n\n\nclass ParallelQwen2ForCausalLM(nn.Module):\n    def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.model = ParallelQwen2Model(config, megatron_config=megatron_config)\n        self.vocab_size = config.vocab_size\n\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        if megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)\n\n        self.lm_head = tensor_parallel.ColumnParallelLinear(\n            input_size=config.hidden_size,\n            output_size=config.vocab_size,\n            bias=False,\n            gather_output=False,\n            skip_bias_add=False,\n            **column_kwargs,\n        )\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        r\"\"\"\n        Args:\n            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n        Returns:\n        ```\"\"\"\n\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs = self.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n        )\n\n        hidden_states = outputs\n        logits = self.lm_head(hidden_states)[0]\n\n        logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)\n\n        logits = logits.float()\n        return CausalLMOutputWithPast(\n            loss=None,\n            logits=logits,\n            past_key_values=None,\n            hidden_states=None,\n            attentions=None,\n        )\n\n\nfrom flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa\n\n\nclass ParallelQwen2ModelRmPad(nn.Module):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Qwen2DecoderLayer`]\n\n    Args:\n        config: Qwen2Config\n    \"\"\"\n\n    def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n        embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()\n        self.megatron_config = megatron_config\n        if megatron_config is not None:\n            assert embedding_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)\n        self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs)\n\n        self.layers = nn.ModuleList([ParallelQwen2DecoderLayerRmPad(config, megatron_config) for _ in range(config.num_hidden_layers)])\n        self.norm = ParallelQwen2RMSNorm(config, megatron_config)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: Optional[torch.LongTensor] = None,\n        sequence_length: int = None,\n        indices: torch.Tensor = None,\n        cu_seqlens: int = None,\n        max_seqlen_in_batch: int = None,\n    ) -> Union[Tuple, BaseModelOutputWithPast]:\n        \"\"\"\n\n        Args:\n            input_ids: input ids. shape (1, totol_nnz)\n            position_ids: position ids. shape (batch_size, seq_length)\n\n        Returns:\n\n        \"\"\"\n        inputs_embeds = self.embed_tokens(input_ids)  # (1, total_nnz) -> (1, total_nnz, hidden_size)\n\n        # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size)\n        inputs_embeds = inputs_embeds.transpose(0, 1)\n        if self.megatron_config.sequence_parallel:\n            inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds)\n\n        hidden_states = inputs_embeds\n        for idx, decoder_layer in enumerate(self.layers):\n            layer_outputs = decoder_layer(\n                hidden_states,\n                position_ids=position_ids,\n                sequence_length=sequence_length,\n                indices=indices,\n                cu_seqlens=cu_seqlens,\n                max_seqlen_in_batch=max_seqlen_in_batch,\n            )\n\n            hidden_states = layer_outputs\n\n        hidden_states = self.norm(hidden_states)\n\n        return hidden_states\n\n\nclass ParallelQwen2ForCausalLMRmPad(nn.Module):\n    def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.megatron_config = megatron_config\n        self.model = ParallelQwen2ModelRmPad(config, megatron_config=megatron_config)\n        self.vocab_size = config.vocab_size\n        self._init_head(config)\n\n    def _init_head(self, config: Qwen2Config):\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        if self.megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)\n        self.lm_head = tensor_parallel.ColumnParallelLinear(\n            input_size=config.hidden_size,\n            output_size=config.vocab_size,\n     
       bias=False,\n            gather_output=False,\n            skip_bias_add=False,\n            **column_kwargs,\n        )\n\n    def _forward_head(self, hidden_states):\n        # all_gather from sequence parallel region is performed inside lm_head\n        logits = self.lm_head(hidden_states)[0]\n        logits = logits.float()  # (total_nnz_padded, 1, vocab_size // tp)\n        logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)  # (total_nnz_padded, 1, vocab_size)\n        return logits\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        r\"\"\"\n        Args:\n            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n        Returns:\n        ```\"\"\"\n        batch_size, sequence_length = input_ids.shape\n\n        # remove padding here\n        input_ids, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1), attention_mask)  # (total_nnz, 1)\n\n        # pad input_ids to multiple of tp for all tp ranks\n        # TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap\n        if self.megatron_config.sequence_parallel:\n            input_ids = sp_utils.pad_to_sequence_parallel(input_ids)\n\n        input_ids = input_ids.transpose(0, 1)  # (1, total_nnz+pad)\n\n        outputs = self.model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            sequence_length=sequence_length,\n            indices=indices,\n            cu_seqlens=cu_seqlens,\n            max_seqlen_in_batch=max_seqlen_in_batch,\n        )\n\n        hidden_states = outputs\n\n        logits = self._forward_head(hidden_states)\n\n        # remove padding from sequence parallel\n        if self.megatron_config.sequence_parallel:\n            totol_nnz = cu_seqlens[-1]\n            logits = logits[:totol_nnz]  # (total_nnz_padded)\n\n        logits = torch.squeeze(logits, dim=1)  # remove the artificial batch dimension\n        # add removed padding back\n        logits = pad_input(logits, indices, batch_size, seqlen=sequence_length)  # (batch_size, sequence_length, vocab_size)\n\n        return CausalLMOutputWithPast(\n            loss=None,\n            logits=logits,\n            past_key_values=None,\n            hidden_states=None,\n            attentions=None,\n        )\n\n\nclass ParallelQwen2ForValueRmPad(ParallelQwen2ForCausalLMRmPad):\n    def _init_head(self, config):\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        if self.megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)\n        self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False)\n        # lm_head is effectively the same as sequence parallel\n        
sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight)\n\n    def _forward_head(self, hidden_states):\n        logits = self.lm_head(hidden_states)  # (total_nnz_padded // tp, 1, 1)\n        logits = logits.float()\n        if self.megatron_config.sequence_parallel:\n            logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False)\n        return logits\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        output = super().forward(input_ids, attention_mask, position_ids)\n        output.logits = torch.squeeze(output.logits, dim=-1)\n        return output\n\n\n\"\"\"\nSupport pipeline parallelism\n\"\"\"\n\n\nclass ParallelQwen2ModelRmPadPP(nn.Module):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]\n    This model definition supports pipeline parallelism. To support pp and vpp,\n    - This model only contains layer in this pp stage and vpp chunk\n    - When calling get_model in Megatron, this rank will instantiate all the vpp chunks in this pp.\n    Args:\n        config: Qwen2Config\n    \"\"\"\n\n    def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig, pre_process, post_process):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n        self.pre_process = pre_process\n        self.post_process = post_process\n        self.megatron_config = megatron_config\n        embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()\n        if megatron_config is not None:\n            assert embedding_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)\n        if pre_process:\n            self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs)\n        else:\n            self.embed_tokens = None\n\n        pp_rank = mpu.get_pipeline_model_parallel_rank()\n        pp_size = megatron_config.pipeline_model_parallel_size\n        self.num_layer_per_pp = config.num_hidden_layers // pp_size\n        vpp_size = megatron_config.virtual_pipeline_model_parallel_size\n        vpp_rank = mpu.get_virtual_pipeline_model_parallel_rank()\n\n        if vpp_size is not None:\n            self.num_layer_vpp_chunk = self.num_layer_per_pp // vpp_size\n            self.num_layer_this_model = self.num_layer_vpp_chunk\n            offset = vpp_rank * (config.num_hidden_layers // vpp_size) + (pp_rank * self.num_layer_vpp_chunk)\n        else:\n            self.num_layer_this_model = self.num_layer_per_pp\n            offset = pp_rank * self.num_layer_per_pp\n\n        self.layers = nn.ModuleList()\n        for i in range(self.num_layer_this_model):\n            layer = ParallelQwen2DecoderLayerRmPad(config, megatron_config, layer_idx=i + offset)\n            self.layers.add_module(f\"{i}\", layer)\n\n        if post_process:\n            self.norm = ParallelQwen2RMSNorm(config, megatron_config)\n        else:\n            self.norm = None\n\n    def set_input_tensor(self, input_tensor):\n        \"\"\"Set input 
tensor to be used instead of forward()'s input.\n\n        When doing pipeline parallelism the input from the previous\n        stage comes from communication, not from the input, so the\n        model's forward_step_func won't have it. This function is thus\n        used by internal code to bypass the input provided by the\n        forward_step_func\"\"\"\n        self.input_tensor = input_tensor\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: Optional[torch.LongTensor] = None,\n        sequence_length: int = None,\n        indices: torch.Tensor = None,\n        cu_seqlens: int = None,\n        max_seqlen_in_batch: int = None,\n    ) -> Union[Tuple, BaseModelOutputWithPast]:\n        \"\"\"\n\n        Args:\n            input_ids: input ids. shape (1, totol_nnz)\n            position_ids: position ids. shape (batch_size, seq_length)\n\n        Returns:\n\n        \"\"\"\n        if self.pre_process:\n            inputs_embeds = self.embed_tokens(input_ids)  # (1, total_nnz) -> (1, total_nnz, hidden_size)\n\n            # vocab parallel embedding will not do sequence parallel reduce-scatter in open source megatron\n            # so need to deal with it by handle here:\n            # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size)\n            inputs_embeds = inputs_embeds.transpose(0, 1)\n            if self.megatron_config.sequence_parallel:\n                inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds)\n\n            hidden_states = inputs_embeds\n        else:\n            # self.hidden_states should be passed by Megatron\n            hidden_states = self.input_tensor\n\n        for idx, decoder_layer in enumerate(self.layers):\n            layer_outputs = decoder_layer(\n                hidden_states,\n                position_ids=position_ids,\n                sequence_length=sequence_length,\n                indices=indices,\n                cu_seqlens=cu_seqlens,\n                max_seqlen_in_batch=max_seqlen_in_batch,\n            )\n\n            hidden_states = layer_outputs\n\n        if self.post_process:\n            hidden_states = self.norm(hidden_states)\n\n        return hidden_states\n\n\nclass ParallelQwen2ForCausalLMRmPadPP(nn.Module):\n    def __init__(\n        self,\n        config: Qwen2Config,\n        megatron_config: ModelParallelConfig,\n        pre_process,\n        post_process,\n        share_embeddings_and_output_weights,\n    ):\n        super().__init__()\n        self.config: TransformerConfig = convert_config(config, megatron_config)\n        self.megatron_config = megatron_config\n        self.model = ParallelQwen2ModelRmPadPP(config, megatron_config=megatron_config, pre_process=pre_process, post_process=post_process)\n        self.share_embeddings_and_output_weights = share_embeddings_and_output_weights\n        self.vocab_size = config.vocab_size\n        self.pre_process = pre_process\n        self.post_process = post_process\n        if post_process:\n            self._init_head(config)\n        if pre_process or post_process:\n            self.setup_embeddings_and_output_layer()\n\n    def set_input_tensor(self, input_tensor):\n        \"\"\"Set input tensor to be used instead of forward()'s input.\n\n        When doing pipeline parallelism the input from the previous\n        stage comes from communication, not from the input, so the\n        model's forward_step_func won't have it. 
This function is thus\n        used by internal code to bypass the input provided by the\n        forward_step_func\"\"\"\n        assert len(input_tensor) == 1\n        self.model.set_input_tensor(input_tensor[0])\n\n    def _init_head(self, config):\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        if self.megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)\n        self.lm_head = tensor_parallel.ColumnParallelLinear(\n            input_size=config.hidden_size,\n            output_size=config.vocab_size,\n            bias=False,\n            gather_output=False,\n            skip_bias_add=False,\n            skip_weight_param_allocation=self.pre_process and self.share_embeddings_and_output_weights,\n            **column_kwargs,\n        )\n\n    def setup_embeddings_and_output_layer(self) -> None:\n        \"\"\"Sets up embedding layer in first stage and output layer in last stage.\n\n        This function initalizes word embeddings in the final stage when we are\n        using pipeline parallelism and sharing word embeddings, and sets up param\n        attributes on the embedding and output layers.\n        \"\"\"\n        # Set `is_embedding_or_output_parameter` attribute.\n        if self.pre_process:\n            self.model.embed_tokens.weight.is_embedding_or_output_parameter = True\n        if self.post_process and self.lm_head.weight is not None:\n            self.lm_head.weight.is_embedding_or_output_parameter = True\n\n        if not self.share_embeddings_and_output_weights:\n            return\n\n        if parallel_state.get_pipeline_model_parallel_world_size() == 1:\n            # Zero out wgrad if sharing embeddings between two layers on same\n            # pipeline stage to make sure grad accumulation into main_grad is\n            # correct and does not include garbage values (e.g., from torch.empty).\n            self.shared_embedding_or_output_weight().zero_out_wgrad = True\n            return\n\n        if parallel_state.is_pipeline_first_stage() and self.pre_process and not self.post_process:\n            self.shared_embedding_or_output_weight().shared_embedding = True\n\n        if self.post_process and not self.pre_process:\n            assert not parallel_state.is_pipeline_first_stage()\n            # set word_embeddings weights to 0 here, then copy first\n            # stage's weights using all_reduce below.\n            self.lm_head.weight.data.fill_(0)\n            self.lm_head.weight.shared = True\n            self.lm_head.weight.shared_embedding = True\n\n        if torch.distributed.is_initialized() and parallel_state.is_rank_in_embedding_group():\n            weight = self.shared_embedding_or_output_weight()\n            weight.data = weight.data.to(get_device_name())\n            torch.distributed.all_reduce(weight.data, group=parallel_state.get_embedding_group())\n\n    def shared_embedding_or_output_weight(self) -> torch.Tensor:\n        if self.pre_process:\n            return self.model.embed_tokens.weight\n        elif self.post_process:\n            return self.lm_head.weight\n        return None\n\n    def _forward_head(self, hidden_states):\n        # all_gather from sequence parallel region is performed inside lm_head\n        # print(f'logits shape before forward_head: {hidden_states.shape}, vocab_size = {self.config.vocab_size}') # [4, 32, 4096]\n        
output_weight = None\n        if self.share_embeddings_and_output_weights:\n            output_weight = self.shared_embedding_or_output_weight()\n        logits = self.lm_head(hidden_states, weight=output_weight)[0]\n        # print(f'logits shape after forward_head: {logits.shape}') # [8, 32, 8]\n        logits = logits.float()  # (total_nnz_padded, 1, vocab_size // tp)\n        return logits\n\n    def forward(\n        self,\n        # original input\n        *,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        r\"\"\"\n        Args:\n            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n        Returns:\n        ```\"\"\"\n\n        # Note that input_ids, attention_mask and position_ids should be passed to every pp layer.\n        # In the first pp, input_ids will be used, in other pp layers hidden_states will be used inside self.model\n        batch_size, sequence_length = input_ids.shape\n        # remove padding here\n        input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1), attention_mask)  # (total_nnz, 1)\n\n        # pad input_ids to multiple of tp for all tp ranks\n        # TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap\n        if self.megatron_config.sequence_parallel:\n            input_ids_rmpad = sp_utils.pad_to_sequence_parallel(input_ids_rmpad)\n\n        input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz+pad)\n\n        outputs = self.model(\n            input_ids=input_ids_rmpad,\n            position_ids=position_ids,\n            sequence_length=sequence_length,\n            indices=indices,\n            cu_seqlens=cu_seqlens,\n            max_seqlen_in_batch=max_seqlen_in_batch,\n        )\n\n        if self.post_process:\n            hidden_states = outputs\n            logits = self._forward_head(hidden_states)\n            logits = torch.squeeze(logits, dim=1)  # remove the artificial batch dimension # torch.Size([8, 32, 16])\n\n            # remove padding from sequence parallel\n            if self.megatron_config.sequence_parallel:\n                totol_nnz = cu_seqlens[-1]\n                logits = logits[:totol_nnz]  # (total_nnz_padded)\n            # add removed padding back. 
If input is already rmpad, we let the caller pad_input\n            logits = pad_input(logits, indices, batch_size, seqlen=sequence_length)  # (batch_size, sequence_length, vocab_size)\n\n            return CausalLMOutputWithPast(\n                loss=None,\n                logits=logits,\n                past_key_values=None,\n                hidden_states=None,\n                attentions=None,\n            )\n        else:\n            return outputs\n\n\nclass ParallelQwen2ForValueRmPadPP(ParallelQwen2ForCausalLMRmPadPP):\n    def _init_head(self, config):\n        column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()\n        if self.megatron_config is not None:\n            assert column_kwargs.get(\"config\", False), \"must have ModelParallelConfig\"\n            tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)\n        self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False)\n        # lm_head is effectively the same as sequence parallel\n        sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight)\n\n    def _forward_head(self, hidden_states):\n        logits = self.lm_head(hidden_states)  # (total_nnz_padded // tp, 1, 1)\n        logits = logits.float()\n        if self.megatron_config.sequence_parallel:\n            logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False)\n        return logits\n\n    def forward(\n        self,\n        *,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)\n        if self.post_process:\n            output.logits = torch.squeeze(output.logits, dim=-1)\n            return output\n        else:\n            return output\n"
  },
  {
    "path": "siirl/models/registry.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport importlib\nfrom typing import List, Optional, Type\n\nimport torch.nn as nn\n\n# Supported models in Megatron-LM\n# Architecture -> (module, class).\n_MODELS = {\n    \"LlamaForCausalLM\": (\n        \"llama\",\n        (\"ParallelLlamaForCausalLMRmPadPP\", \"ParallelLlamaForValueRmPadPP\", \"ParallelLlamaForCausalLMRmPad\"),\n    ),\n    \"Qwen2ForCausalLM\": (\n        \"qwen2\",\n        (\"ParallelQwen2ForCausalLMRmPadPP\", \"ParallelQwen2ForValueRmPadPP\", \"ParallelQwen2ForCausalLMRmPad\"),\n    ),\n    \"MistralForCausalLM\": (\n        \"mistral\",\n        (\"ParallelMistralForCausalLMRmPadPP\", \"ParallelMistralForValueRmPadPP\", \"ParallelMistralForCausalLMRmPad\"),\n    ),\n}\n\n\n# return model class\nclass ModelRegistry:\n    @staticmethod\n    def load_model_cls(model_arch: str, value=False) -> Optional[Type[nn.Module]]:\n        if model_arch not in _MODELS:\n            return None\n\n        megatron = \"megatron\"\n\n        module_name, model_cls_name = _MODELS[model_arch]\n        if not value:  # actor/ref\n            model_cls_name = model_cls_name[0]\n        elif value:  # critic/rm\n            model_cls_name = model_cls_name[1]\n\n        module = importlib.import_module(f\"siirl.models.{module_name}.{megatron}.modeling_{module_name}_megatron\")\n        return getattr(module, model_cls_name, None)\n\n    @staticmethod\n    def get_supported_archs() -> List[str]:\n        return list(_MODELS.keys())\n"
  },
  {
    "path": "siirl/models/transformers/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "siirl/models/transformers/internvl.py",
    "content": "# --------------------------------------------------------\n# InternVL\n# Copyright (c) 2024 OpenGVLab\n# Licensed under The MIT License [see LICENSE for details]\n# --------------------------------------------------------\n\n# --------------------------------------------------------\n# InternVL\n# Copyright (c) 2024 OpenGVLab\n# Licensed under The MIT License [see LICENSE for details]\n# --------------------------------------------------------\n\nimport io\nimport dataclasses\nfrom enum import IntEnum, auto\n\nfrom transformers.trainer_pt_utils import LabelSmoother\n\nIGNORE_TOKEN_ID = LabelSmoother.ignore_index\nimport os\nimport random\nimport re\nfrom collections import Counter\nfrom typing import Dict, Tuple, List, Union\n\nimport imageio\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torchvision.transforms as T\nimport transformers\nfrom PIL import Image\nfrom torch.utils.data import ConcatDataset, WeightedRandomSampler\nfrom torchvision.transforms.functional import InterpolationMode\nfrom loguru import logger\n\n\nIMG_CONTEXT_TOKEN = \"<IMG_CONTEXT>\"\nIMG_START_TOKEN = \"<img>\"\nIMG_END_TOKEN = \"</img>\"\nQUAD_START_TOKEN = \"<quad>\"\nQUAD_END_TOKEN = \"</quad>\"\nREF_START_TOKEN = \"<ref>\"\nREF_END_TOKEN = \"</ref>\"\nBOX_START_TOKEN = \"<box>\"\nBOX_END_TOKEN = \"</box>\"\nIMAGENET_MEAN = (0.485, 0.456, 0.406)\nIMAGENET_STD = (0.229, 0.224, 0.225)\nCLIP_MEAN = (0.4814546, 0.4578275, 0.40821073)\nCLIP_STD = (0.2686295, 0.2613025, 0.2757711)\nSIGLIP_MEAN = (0.5, 0.5, 0.5)\nSIGLIP_STD = (0.5, 0.5, 0.5)\n\n\nclass SeparatorStyle(IntEnum):\n    \"\"\"Separator styles.\"\"\"\n\n    ADD_COLON_SINGLE = auto()\n    ADD_COLON_TWO = auto()\n    ADD_COLON_SPACE_SINGLE = auto()\n    NO_COLON_SINGLE = auto()\n    NO_COLON_TWO = auto()\n    ADD_NEW_LINE_SINGLE = auto()\n    LLAMA2 = auto()\n    CHATGLM = auto()\n    CHATML = auto()\n    CHATINTERN = auto()\n    DOLLY = auto()\n    RWKV = auto()\n    PHOENIX = auto()\n    ROBIN = auto()\n    FALCON_CHAT = auto()\n    CHATGLM3 = auto()\n    INTERNVL_ZH = auto()\n    MPT = auto()\n\n\n@dataclasses.dataclass\nclass Conversation:\n    \"\"\"A class that manages prompt templates and keeps all conversation history.\"\"\"\n\n    # The name of this template\n    name: str\n    # The template of the system prompt\n    system_template: str = \"{system_message}\"\n    # The system message\n    system_message: str = \"\"\n    # The names of two roles\n    roles: Tuple[str] = (\"USER\", \"ASSISTANT\")\n    # All messages. 
Each item is (role, message).\n    messages: List[List[str]] = ()\n    # The number of few shot examples\n    offset: int = 0\n    # The separator style and configurations\n    sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE\n    sep: str = \"\\n\"\n    sep2: str = None\n    # Stop criteria (the default one is EOS token)\n    stop_str: Union[str, List[str]] = None\n    # Stops generation if meeting any token in this list\n    stop_token_ids: List[int] = None\n\n    def get_prompt(self) -> str:\n        \"\"\"Get the prompt for generation.\"\"\"\n        system_prompt = self.system_template.format(system_message=self.system_message)\n        if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:\n            ret = system_prompt + self.sep\n            for role, message in self.messages:\n                if message:\n                    ret += role + \": \" + message + self.sep\n                else:\n                    ret += role + \":\"\n            return ret\n        elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:\n            seps = [self.sep, self.sep2]\n            ret = system_prompt + seps[0]\n            for i, (role, message) in enumerate(self.messages):\n                if message:\n                    ret += role + \": \" + message + seps[i % 2]\n                else:\n                    ret += role + \":\"\n            return ret\n        elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:\n            ret = system_prompt + self.sep\n            for role, message in self.messages:\n                if message:\n                    ret += role + \": \" + message + self.sep\n                else:\n                    ret += role + \": \"  # must be end with a space\n            return ret\n        elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:\n            ret = \"\" if system_prompt == \"\" else system_prompt + self.sep\n            for role, message in self.messages:\n                if message:\n                    ret += role + \"\\n\" + message + self.sep\n                else:\n                    ret += role + \"\\n\"\n            return ret\n        elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:\n            ret = system_prompt\n            for role, message in self.messages:\n                if message:\n                    ret += role + message + self.sep\n                else:\n                    ret += role\n            return ret\n        elif self.sep_style == SeparatorStyle.NO_COLON_TWO:\n            seps = [self.sep, self.sep2]\n            ret = system_prompt\n            for i, (role, message) in enumerate(self.messages):\n                if message:\n                    ret += role + message + seps[i % 2]\n                else:\n                    ret += role\n            return ret\n        elif self.sep_style == SeparatorStyle.RWKV:\n            ret = system_prompt\n            for i, (role, message) in enumerate(self.messages):\n                if message:\n                    ret += role + \": \" + message.replace(\"\\r\\n\", \"\\n\").replace(\"\\n\\n\", \"\\n\")\n                    ret += \"\\n\\n\"\n                else:\n                    ret += role + \":\"\n            return ret\n        elif self.sep_style == SeparatorStyle.LLAMA2:\n            seps = [self.sep, self.sep2]\n            if self.system_message:\n                ret = system_prompt\n            else:\n                ret = \"[INST] \"\n            for i, (role, message) in enumerate(self.messages):\n                tag = self.roles[i 
% 2]\n                if message:\n                    if i == 0:\n                        ret += message + \" \"\n                    else:\n                        ret += tag + \" \" + message + seps[i % 2]\n                else:\n                    ret += tag\n            return ret\n        elif self.sep_style == SeparatorStyle.CHATGLM:\n            # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308\n            # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926\n            round_add_n = 1 if self.name == \"chatglm2\" else 0\n            if system_prompt:\n                ret = system_prompt + self.sep\n            else:\n                ret = \"\"\n\n            for i, (role, message) in enumerate(self.messages):\n                if i % 2 == 0:\n                    ret += f\"[Round {i // 2 + round_add_n}]{self.sep}\"\n\n                if message:\n                    ret += f\"{role}：{message}{self.sep}\"\n                else:\n                    ret += f\"{role}：\"\n            return ret\n        elif self.sep_style == SeparatorStyle.CHATML:\n            ret = \"\" if system_prompt == \"\" else system_prompt + self.sep + \"\\n\"\n            for role, message in self.messages:\n                if message:\n                    ret += role + \"\\n\" + message + self.sep + \"\\n\"\n                else:\n                    ret += role + \"\\n\"\n            return ret\n        elif self.sep_style == SeparatorStyle.CHATGLM3:\n            ret = \"\"\n            if self.system_message:\n                ret += system_prompt\n            for role, message in self.messages:\n                if message:\n                    ret += role + \"\\n\" + \" \" + message\n                else:\n                    ret += role\n            return ret\n        elif self.sep_style == SeparatorStyle.CHATINTERN:\n            # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771\n            seps = [self.sep, self.sep2]\n            ret = system_prompt\n            for i, (role, message) in enumerate(self.messages):\n                # if i % 2 == 0:\n                #     ret += \"<s>\"\n                if message:\n                    ret += role + \":\" + message + seps[i % 2] + \"\\n\"\n                else:\n                    ret += role + \":\"\n            return ret\n        elif self.sep_style == SeparatorStyle.DOLLY:\n            seps = [self.sep, self.sep2]\n            ret = system_prompt\n            for i, (role, message) in enumerate(self.messages):\n                if message:\n                    ret += role + \":\\n\" + message + seps[i % 2]\n                    if i % 2 == 1:\n                        ret += \"\\n\\n\"\n                else:\n                    ret += role + \":\\n\"\n            return ret\n        elif self.sep_style == SeparatorStyle.PHOENIX:\n            ret = system_prompt\n            for role, message in self.messages:\n                if message:\n                    ret += role + \": \" + \"<s>\" + message + \"</s>\"\n                else:\n                    ret += role + \": \" + \"<s>\"\n            return ret\n        elif self.sep_style == SeparatorStyle.ROBIN:\n            ret = system_prompt + self.sep\n            for role, message in self.messages:\n                if message:\n                    ret += role + \":\\n\" 
+ message + self.sep\n                else:\n                    ret += role + \":\\n\"\n            return ret\n        elif self.sep_style == SeparatorStyle.FALCON_CHAT:\n            ret = \"\"\n            if self.system_message:\n                ret += system_prompt + self.sep\n            for role, message in self.messages:\n                if message:\n                    ret += role + \": \" + message + self.sep\n                else:\n                    ret += role + \":\"\n\n            return ret\n        elif self.sep_style == SeparatorStyle.INTERNVL_ZH:\n            seps = [self.sep2, self.sep]\n            ret = self.system_message + seps[0]\n            for i, (role, message) in enumerate(self.messages):\n                if message:\n                    ret += role + \": \" + message + seps[i % 2]\n                else:\n                    ret += role + \":\"\n            return ret\n        elif self.sep_style == SeparatorStyle.MPT:\n            ret = system_prompt + self.sep\n            for role, message in self.messages:\n                if message:\n                    if type(message) is tuple:\n                        message, _, _ = message\n                    ret += role + message + self.sep\n                else:\n                    ret += role\n            return ret\n        else:\n            raise ValueError(f\"Invalid style: {self.sep_style}\")\n\n    def set_system_message(self, system_message: str):\n        \"\"\"Set the system message.\"\"\"\n        self.system_message = system_message\n\n    def append_message(self, role: str, message: str):\n        \"\"\"Append a new message.\"\"\"\n        self.messages.append([role, message])\n\n    def update_last_message(self, message: str):\n        \"\"\"Update the last output.\n\n        The last message is typically set to be None when constructing the prompt,\n        so we need to update it in-place after getting the response from a model.\n        \"\"\"\n        self.messages[-1][1] = message\n\n    def to_gradio_chatbot(self):\n        \"\"\"Convert the conversation to gradio chatbot format.\"\"\"\n        ret = []\n        for i, (role, msg) in enumerate(self.messages[self.offset :]):\n            if i % 2 == 0:\n                ret.append([msg, None])\n            else:\n                ret[-1][-1] = msg\n        return ret\n\n    def to_openai_api_messages(self):\n        \"\"\"Convert the conversation to OpenAI chat completion format.\"\"\"\n        ret = [{\"role\": \"system\", \"content\": self.system_message}]\n\n        for i, (_, msg) in enumerate(self.messages[self.offset :]):\n            if i % 2 == 0:\n                ret.append({\"role\": \"user\", \"content\": msg})\n            else:\n                if msg is not None:\n                    ret.append({\"role\": \"assistant\", \"content\": msg})\n        return ret\n\n    def copy(self):\n        return Conversation(\n            name=self.name,\n            system_template=self.system_template,\n            system_message=self.system_message,\n            roles=self.roles,\n            messages=[[x, y] for x, y in self.messages],\n            offset=self.offset,\n            sep_style=self.sep_style,\n            sep=self.sep,\n            sep2=self.sep2,\n            stop_str=self.stop_str,\n            stop_token_ids=self.stop_token_ids,\n        )\n\n    def dict(self):\n        return {\n            \"template_name\": self.name,\n            \"system_message\": self.system_message,\n            \"roles\": self.roles,\n            
\"messages\": self.messages,\n            \"offset\": self.offset,\n        }\n\n\nconv_templates: Dict[str, Conversation] = {}\n\n\ndef register_conv_template(template: Conversation, override: bool = False):\n    \"\"\"Register a new conversation template.\"\"\"\n    if not override:\n        assert template.name not in conv_templates, f\"{template.name} has been registered.\"\n\n    conv_templates[template.name] = template\n\n\ndef get_conv_template(name: str) -> Conversation:\n    \"\"\"Get a conversation template.\"\"\"\n    return conv_templates[name].copy()\n\n\nregister_conv_template(\n    Conversation(\n        name=\"internvl2_5\",\n        system_template=\"<|im_start|>system\\n{system_message}\",\n        system_message=\"你是书生·万象，英文名是InternVL，是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。\",\n        roles=(\"<|im_start|>user\\n\", \"<|im_start|>assistant\\n\"),\n        sep_style=SeparatorStyle.MPT,\n        sep=\"<|im_end|>\\n\",\n    )\n)\n\n\ntry:\n    from petrel_client.client import Client\n    from petrel_client.common.config import Config\nexcept ImportError as E:\n    logger.debug(\"petrel_client is not installed. If you read data locally instead of from ceph, ignore it.\")\nimport sys\n\n\ndef calculate_ngram_repetition(text, n):\n    words = text.split()\n    ngrams = [tuple(words[i : i + n]) for i in range(len(words) - n + 1)]\n    ngram_counts = Counter(ngrams)\n    total_ngrams = len(ngrams)\n    repeated_ngrams = sum(1 for count in ngram_counts.values() if count > 1)\n    return repeated_ngrams / total_ngrams if total_ngrams > 0 else 0\n\n\ndef check_conversations_repetition(conversations, repeat_threshold=0.4, ngram=10):\n    for conversation in conversations:\n        if conversation[\"from\"] == \"gpt\":\n            model_answer = conversation[\"value\"]\n            repeat_ratio = calculate_ngram_repetition(model_answer, ngram)\n            if repeat_ratio > repeat_threshold:\n                raise Exception\n\n\ndef get_frame_indices(num_frames, vlen, sample=\"rand\", fix_start=None, input_fps=1, max_num_frames=-1):\n    if sample in [\"rand\", \"middle\"]:  # uniform sampling\n        acc_samples = min(num_frames, vlen)\n        # split the video into `acc_samples` intervals, and sample from each interval.\n        intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)\n        ranges = []\n        for idx, interv in enumerate(intervals[:-1]):\n            ranges.append((interv, intervals[idx + 1] - 1))\n        if sample == \"rand\":\n            try:\n                frame_indices = [random.choice(range(x[0], x[1])) for x in ranges]\n            except:\n                frame_indices = np.random.permutation(vlen)[:acc_samples]\n                frame_indices.sort()\n                frame_indices = list(frame_indices)\n        elif fix_start is not None:\n            frame_indices = [x[0] + fix_start for x in ranges]\n        elif sample == \"middle\":\n            frame_indices = [(x[0] + x[1]) // 2 for x in ranges]\n        else:\n            raise NotImplementedError\n\n        if len(frame_indices) < num_frames:  # padded with last frame\n            padded_frame_indices = [frame_indices[-1]] * num_frames\n            padded_frame_indices[: len(frame_indices)] = frame_indices\n            frame_indices = padded_frame_indices\n    elif \"fps\" in sample:  # fps0.5, sequentially sample frames at 0.5 fps\n        output_fps = float(sample[3:])\n        duration = float(vlen) / input_fps\n        delta = 1 / output_fps  # gap between frames, 
this is also the clip length each frame represents\n        frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)\n        frame_indices = np.around(frame_seconds * input_fps).astype(int)\n        frame_indices = [e for e in frame_indices if e < vlen]\n        if max_num_frames > 0 and len(frame_indices) > max_num_frames:\n            frame_indices = frame_indices[:max_num_frames]\n            # frame_indices = np.linspace(0 + delta / 2, duration + delta / 2, endpoint=False, num=max_num_frames)\n    else:\n        raise ValueError\n    return frame_indices\n\n\ndef read_frames_gif(video_path, num_frames, sample=\"rand\", fix_start=None, client=None, min_num_frames=4):\n    if \"s3://\" in video_path:\n        video_bytes = client.get(video_path)\n        gif = imageio.get_reader(io.BytesIO(video_bytes))\n    else:\n        gif = imageio.get_reader(video_path)\n    vlen = len(gif)\n\n    t_num_frames = np.random.randint(min_num_frames, num_frames + 1)\n    frame_indices = get_frame_indices(t_num_frames, vlen, sample=sample, fix_start=fix_start)\n    frames = []\n    for index, frame in enumerate(gif):\n        if index in frame_indices:\n            import cv2\n            frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB).astype(np.uint8)\n            frame = Image.fromarray(frame)\n            frames.append(frame)\n    return frames\n\n\ndef read_frames_decord(video_path, num_frames, sample=\"rand\", fix_start=None, client=None, clip=None, min_num_frames=4):\n    from decord import VideoReader\n\n    if \"s3://\" in video_path:\n        video_bytes = client.get(video_path)\n        video_reader = VideoReader(io.BytesIO(video_bytes), num_threads=1)\n    else:\n        video_reader = VideoReader(video_path, num_threads=1)\n    vlen = len(video_reader)\n    fps = video_reader.get_avg_fps()\n    duration = vlen / float(fps)\n    if clip:\n        start, end = clip\n        duration = end - start\n        vlen = int(duration * fps)\n        start_index = int(start * fps)\n\n    # t_num_frames = min(max(int(duration * sample_fps), min_num_frames), num_frames)\n    t_num_frames = np.random.randint(min_num_frames, num_frames + 1)\n\n    frame_indices = get_frame_indices(t_num_frames, vlen, sample=sample, fix_start=fix_start, input_fps=fps)\n    if clip:\n        frame_indices = [f + start_index for f in frame_indices]\n    frames = video_reader.get_batch(frame_indices).asnumpy()  # (T, H, W, C), np.uint8\n    frames = [Image.fromarray(frames[i]) for i in range(frames.shape[0])]\n    return frames\n\n\ndef extract_frame_number(filename):\n    # Extract the numeric part from the filename using regular expressions\n    match = re.search(r\"_(\\d+).jpg$\", filename)\n    return int(match.group(1)) if match else -1\n\n\ndef sort_frames(frame_paths):\n    # Extract filenames from each path and sort by their numeric part\n    return sorted(frame_paths, key=lambda x: extract_frame_number(os.path.basename(x)))\n\n\ndef read_frames_folder(video_path, num_frames, sample=\"rand\", fix_start=None, client=None, clip=None, min_num_frames=4):\n    if \"s3://\" in video_path:\n        image_list = sort_frames(client.list(video_path))\n        frames = []\n        for image in image_list:\n            fp = os.path.join(video_path, image)\n            frame = Image.open(io.BytesIO(client.get(fp)))\n            frames.append(frame)\n    else:\n        image_list = sort_frames(list(os.listdir(video_path)))\n        frames = []\n        for image in image_list:\n            fp = os.path.join(video_path, 
image)\n            frame = Image.open(fp).convert(\"RGB\")\n            frames.append(frame)\n    vlen = len(frames)\n\n    t_num_frames = np.random.randint(min_num_frames, num_frames + 1)\n\n    if vlen > t_num_frames:\n        frame_indices = get_frame_indices(t_num_frames, vlen, sample=sample, fix_start=fix_start)\n        frames = [frames[i] for i in frame_indices]\n    return frames\n\n\nclass WeightedConcatDataset(ConcatDataset):\n    def __init__(self, datasets, weights):\n        super().__init__(datasets)\n        self.weights = torch.DoubleTensor(weights)\n        self.total_size = sum(len(d) for d in datasets)\n        self.sampler = WeightedRandomSampler(weights=self.weights, num_samples=self.total_size, replacement=True)\n\n    def __iter__(self):\n        return iter(self.sampler)\n\n    def __len__(self):\n        return self.total_size\n\n\ndef pil_loader(img_str):\n    buff = io.BytesIO(img_str)\n    img = Image.open(buff)\n    return img.convert(\"RGB\")\n\n\nclass TCSLoader(object):\n    def __init__(self, conf_path, sc_config_key=\"sensecore\"):\n        print(f\"[TCSLoader] config_path: {conf_path}\")\n        print(\"--> before Client(conf_path)\")\n        self.client = Client(conf_path)\n        self.sc_config_key = sc_config_key\n        print(\"--> after Client(conf_path)\")\n\n    def __call__(self, fn, image_type=\"image\", max_num_frames=-1, min_num_frames=8, sample=\"rand\", clip=None):\n        if image_type == \"image\":\n            img_value_str = self.client.get(fn)\n            img = pil_loader(img_value_str)\n            return img\n\n        elif image_type == \"video\":\n            if fn.endswith(\"/\"):\n                frames = read_frames_folder(fn, num_frames=max_num_frames, min_num_frames=min_num_frames, client=self.client, sample=sample)\n            elif fn.endswith(\".gif\"):\n                frames = read_frames_gif(fn, num_frames=max_num_frames, min_num_frames=min_num_frames, client=self.client, sample=sample)\n            else:\n                frames = read_frames_decord(fn, num_frames=max_num_frames, min_num_frames=min_num_frames, client=self.client, sample=sample, clip=clip)\n            return frames\n\n\ndef expand2square(pil_img, background_color):\n    width, height = pil_img.size\n    if width == height:\n        return pil_img\n    elif width > height:\n        result = Image.new(pil_img.mode, (width, width), background_color)\n        result.paste(pil_img, (0, (width - height) // 2))\n        return result\n    else:\n        result = Image.new(pil_img.mode, (height, height), background_color)\n        result.paste(pil_img, ((height - width) // 2, 0))\n        return result\n\n\ndef simulate_jpeg_degradation(quality):\n    def jpeg_degrade(img):\n        with io.BytesIO() as output:\n            img.convert(\"RGB\").save(output, format=\"JPEG\", quality=quality)\n            output.seek(0)  # Move the reading cursor to the start of the stream\n            img_jpeg = Image.open(output).copy()  # Use .copy() to make sure the image is loaded in memory\n        return img_jpeg\n\n    return jpeg_degrade\n\n\n# Define the JPEG compression quality range, pre-create all JPEG compression functions\nqualities = list(range(75, 101))\njpeg_degrade_functions = {quality: simulate_jpeg_degradation(quality) for quality in qualities}\n\n\ndef build_transform(is_train, input_size, pad2square=False, normalize_type=\"imagenet\"):\n    if normalize_type == \"imagenet\":\n        MEAN, STD = IMAGENET_MEAN, IMAGENET_STD\n    elif normalize_type == 
\"clip\":\n        MEAN, STD = CLIP_MEAN, CLIP_STD\n    elif normalize_type == \"siglip\":\n        MEAN, STD = SIGLIP_MEAN, SIGLIP_STD\n    else:\n        raise NotImplementedError\n    if is_train:  # use data augumentation\n        transform = T.Compose(\n            [T.Lambda(lambda img: img.convert(\"RGB\") if img.mode != \"RGB\" else img), T.RandomChoice([T.Lambda(jpeg_degrade_functions[quality]) for quality in qualities]), T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), T.ToTensor(), T.Normalize(mean=MEAN, std=STD)]\n        )\n    else:\n        if pad2square is False:  # now we use this transform function by default\n            transform = T.Compose([T.Lambda(lambda img: img.convert(\"RGB\") if img.mode != \"RGB\" else img), T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), T.ToTensor(), T.Normalize(mean=MEAN, std=STD)])\n        else:\n            transform = T.Compose(\n                [T.Lambda(lambda img: img.convert(\"RGB\") if img.mode != \"RGB\" else img), T.Lambda(lambda img: expand2square(img, tuple(int(x * 255) for x in MEAN))), T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), T.ToTensor(), T.Normalize(mean=MEAN, std=STD)]\n            )\n\n    return transform\n\n\ndef preprocess(template_name, sources, tokenizer: transformers.PreTrainedTokenizer, num_image_token_list: list, text_only: bool = False, group_by_length: bool = False, use_packed_ds: bool = False, ds_name: str = None, num_image: int = 1) -> Dict:\n    conv = get_conv_template(template_name)\n    roles = {\"human\": conv.roles[0], \"gpt\": conv.roles[1]}\n\n    # Apply prompt templates\n    conversations = []\n    for i, source in enumerate(sources):\n        if roles[source[0][\"from\"]] != conv.roles[0]:\n            # Skip the first one if it is not from human\n            source = source[1:]\n\n        conv.messages = []\n        for j, sentence in enumerate(source):\n            role = roles[sentence[\"from\"]]\n            assert role == conv.roles[j % 2], f\"{i}\"\n            conv.append_message(role, sentence[\"value\"])\n        conversations.append(conv.get_prompt())\n\n    if not text_only:\n        new_conversations = []\n        for conversation in conversations:\n            for i in range(num_image):\n                image_tokens = f\"{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}\"\n                conversation = conversation.replace(\"<image>\", image_tokens, 1)\n            new_conversations.append(conversation)\n        conversations = new_conversations\n\n    # Tokenize conversations\n    input_ids = tokenizer(\n        conversations,\n        return_tensors=\"pt\",\n        padding=False if group_by_length or use_packed_ds else \"max_length\",\n        max_length=tokenizer.model_max_length,\n        truncation=True,\n    ).input_ids\n    targets = input_ids.clone()\n\n    # assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO\n\n    # Mask targets. 
Only compute loss on the assistant outputs.\n    sep = conv.sep + conv.roles[1] + \": \"\n    for conversation, target in zip(conversations, targets):\n        total_len = int(target.ne(tokenizer.pad_token_id).sum())\n\n        turns = conversation.split(conv.sep2)\n        cur_len = 1\n        target[:cur_len] = IGNORE_TOKEN_ID\n        for i, turn in enumerate(turns):\n            if turn == \"\":\n                break\n            turn_len = len(tokenizer(turn).input_ids)\n\n            parts = turn.split(sep)\n            if len(parts) != 2:\n                break\n            parts[0] += sep\n            # \"-2\" is hardcoded for the Llama tokenizer to make the offset correct.\n            instruction_len = len(tokenizer(parts[0]).input_ids) - 2\n\n            if i != 0 and not tokenizer.legacy:\n                # The legacy and non-legacy modes handle special tokens differently\n                instruction_len -= 1\n\n            # Ignore the user instructions\n            target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID\n            cur_len += turn_len\n\n            if i != 0 and not tokenizer.legacy:\n                # The legacy and non-legacy modes handle special tokens differently\n                cur_len -= 1\n\n        target[cur_len:] = IGNORE_TOKEN_ID\n\n        if False:  # Inspect and check the correctness of masking\n            z = target.clone()\n            z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)\n            logger.info(tokenizer.decode(z))\n            exit()\n\n        if cur_len < tokenizer.model_max_length:\n            if cur_len != total_len:\n                target[:] = IGNORE_TOKEN_ID\n                print(f\"WARNING: tokenization mismatch: {cur_len} vs. {total_len}. #turn = {len(turns) - 1}. (ignored). 
This dataset is {ds_name}.\")\n                sys.stdout.flush()\n\n    return dict(\n        input_ids=input_ids,\n        labels=targets,\n        attention_mask=input_ids.ne(tokenizer.pad_token_id),\n    )\n\n\ndef preprocess_mpt(template_name, sources, tokenizer: transformers.PreTrainedTokenizer, num_image_token_list: list, text_only: bool = False, group_by_length: bool = False, use_packed_ds: bool = False, ds_name: str = None, num_image: int = 1) -> Dict:\n    conv = get_conv_template(template_name)\n    roles = {\"human\": conv.roles[0], \"gpt\": conv.roles[1]}\n\n    # Apply prompt templates\n    conversations = []\n    for i, source in enumerate(sources):\n        if roles[source[0][\"from\"]] != conv.roles[0]:\n            # Skip the first one if it is not from human\n            source = source[1:]\n\n        conv.messages = []\n        for j, sentence in enumerate(source):\n            role = roles[sentence[\"from\"]]\n            assert role == conv.roles[j % 2], f\"{i}\"\n            conv.append_message(role, sentence[\"value\"])\n        conversations.append(conv.get_prompt())\n\n    if not text_only:\n        new_conversations = []\n        for conversation in conversations:\n            for i in range(num_image):\n                image_tokens = f\"{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}\"\n                conversation = conversation.replace(\"<image>\", image_tokens, 1)\n            new_conversations.append(conversation)\n        conversations = new_conversations\n\n    # Tokenize conversations\n    input_ids = tokenizer(\n        conversations,\n        return_tensors=\"pt\",\n        padding=False if group_by_length or use_packed_ds else \"max_length\",\n        max_length=tokenizer.model_max_length,\n        truncation=True,\n    ).input_ids\n    targets = input_ids.clone()\n\n    # Mask targets. Only compute loss on the assistant outputs.\n    sep = conv.sep + conv.roles[1]  # <|im_end|><|im_start|>assistant\\n\n    for conversation, target in zip(conversations, targets):\n        total_len = int(target.ne(tokenizer.pad_token_id).sum())\n\n        turns = conversation.split(conv.sep)\n        re_turns = [conv.sep.join(turns[:3])]  # system + user + gpt\n        for conv_idx in range(3, len(turns), 2):\n            re_turns.append(conv.sep.join(turns[conv_idx : conv_idx + 2]))  # user + gpt\n        cur_len = 0\n        target[:cur_len] = IGNORE_TOKEN_ID\n        for i, turn in enumerate(re_turns):\n            if turn == \"\":\n                break\n            turn_len = len(tokenizer(turn).input_ids) + 1\n\n            parts = turn.split(sep)\n            if len(parts) != 2:\n                break\n            parts[0] += sep\n            instruction_len = len(tokenizer(parts[0]).input_ids)\n\n            # Ignore the user instructions\n            target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID\n            # print(f'[question {i}]', tokenizer.decode(input_ids[:, cur_len: cur_len + instruction_len][0]))\n            # print(f'[answer {i}]', tokenizer.decode(input_ids[:, cur_len + instruction_len: cur_len + turn_len][0]))\n            # print(f'[label {i}]', target[cur_len + instruction_len: cur_len + turn_len])\n            cur_len += turn_len\n\n        target[cur_len:] = IGNORE_TOKEN_ID\n\n        if cur_len < tokenizer.model_max_length:\n            if cur_len != total_len:\n                target[:] = IGNORE_TOKEN_ID\n                print(f\"WARNING: tokenization mismatch: {cur_len} vs. {total_len}. 
#turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.\")\n                sys.stdout.flush()\n\n    return dict(\n        input_ids=input_ids,\n        labels=targets,\n        attention_mask=input_ids.ne(tokenizer.pad_token_id),\n    )\n\n\ndef preprocess_phi3(template_name, sources, tokenizer: transformers.PreTrainedTokenizer, num_image_token_list: list, text_only: bool = False, group_by_length: bool = False, use_packed_ds: bool = False, ds_name: str = None, num_image: int = 1) -> Dict:\n    conv = get_conv_template(template_name)\n    roles = {\"human\": conv.roles[0], \"gpt\": conv.roles[1]}\n\n    # Apply prompt templates\n    conversations = []\n    for i, source in enumerate(sources):\n        if roles[source[0][\"from\"]] != conv.roles[0]:\n            # Skip the first one if it is not from human\n            source = source[1:]\n\n        conv.messages = []\n        for j, sentence in enumerate(source):\n            role = roles[sentence[\"from\"]]\n            assert role == conv.roles[j % 2], f\"{i}\"\n            conv.append_message(role, sentence[\"value\"])\n        conversations.append(conv.get_prompt())\n\n    if not text_only:\n        new_conversations = []\n        for conversation in conversations:\n            for i in range(num_image):\n                image_tokens = f\"{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}\"\n                conversation = conversation.replace(\"<image>\", image_tokens, 1)\n            new_conversations.append(conversation)\n        conversations = new_conversations\n\n    # Tokenize conversations\n    tokenizer.padding_side = \"right\"\n    input_ids = tokenizer(\n        conversations,\n        return_tensors=\"pt\",\n        padding=False if group_by_length or use_packed_ds else \"max_length\",\n        max_length=tokenizer.model_max_length,\n        truncation=True,\n    ).input_ids\n    targets = input_ids.clone()\n\n    # Mask targets. 
Only compute loss on the assistant outputs.\n    sep = conv.sep + conv.roles[1]  # <|end|>\\n<|assistant|>\n    for conversation, target in zip(conversations, targets):\n        total_len = int(target.ne(int(tokenizer.pad_token_id)).sum())\n\n        turns = conversation.split(conv.sep)\n        re_turns = [conv.sep.join(turns[:3])]  # system + user + gpt\n        for conv_idx in range(3, len(turns), 2):\n            re_turns.append(conv.sep.join(turns[conv_idx : conv_idx + 2]))  # user + gpt\n        cur_len = 1\n        target[:cur_len] = IGNORE_TOKEN_ID\n        endoftext_id = tokenizer.convert_tokens_to_ids(\"<|endoftext|>\")\n        target[target == endoftext_id] = IGNORE_TOKEN_ID\n\n        for i, turn in enumerate(re_turns):\n            if turn == \"\":\n                break\n            if i == 0:\n                turn_len = len(tokenizer(turn).input_ids)\n            else:\n                turn_len = len(tokenizer(turn).input_ids) - 1\n            parts = turn.split(sep)\n            if len(parts) != 2:\n                break\n            parts[0] += sep\n\n            if i == 0:\n                instruction_len = len(tokenizer(parts[0]).input_ids) - 1\n            else:\n                instruction_len = len(tokenizer(parts[0]).input_ids) - 2\n\n            # Ignore the user instructions\n            target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID\n            # print(f'[question {i}]', tokenizer.decode(input_ids[:, cur_len: cur_len + instruction_len][0]))\n            # print(f'[answer {i}]', tokenizer.decode(input_ids[:, cur_len + instruction_len: cur_len + turn_len][0]))\n            # print(f'[label {i}]', target[cur_len + instruction_len: cur_len + turn_len])\n            cur_len += turn_len\n\n        target[cur_len:] = IGNORE_TOKEN_ID\n\n        if False:  # Inspect and check the correctness of masking\n            z = target.clone()\n            z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)\n            print(repr(tokenizer.decode(z)))\n\n        if cur_len < tokenizer.model_max_length:\n            if cur_len != total_len:\n                target[:] = IGNORE_TOKEN_ID\n                print(f\"WARNING: tokenization mismatch: {cur_len} vs. {total_len}. #turn = {len(turns) - 1}. (ignored). 
This dataset is {ds_name}.\")\n                sys.stdout.flush()\n\n    return dict(\n        input_ids=input_ids,\n        labels=targets,\n        attention_mask=input_ids.ne(tokenizer.pad_token_id),\n    )\n\n\ndef preprocess_internlm(template_name, sources, tokenizer: transformers.PreTrainedTokenizer, num_image_token_list: list, text_only: bool = False, group_by_length: bool = False, use_packed_ds: bool = False, ds_name: str = None, num_image: int = 1) -> Dict:\n    conv = get_conv_template(template_name)\n    roles = {\"human\": conv.roles[0], \"gpt\": conv.roles[1]}\n\n    # Apply prompt templates\n    conversations = []\n    for i, source in enumerate(sources):\n        if roles[source[0][\"from\"]] != conv.roles[0]:\n            # Skip the first one if it is not from human\n            source = source[1:]\n\n        conv.messages = []\n        for j, sentence in enumerate(source):\n            role = roles[sentence[\"from\"]]\n            assert role == conv.roles[j % 2], f\"{i}\"\n            sentence[\"value\"] = sentence[\"value\"].strip()\n            conv.append_message(role, sentence[\"value\"])\n        conversations.append(conv.get_prompt())\n\n    if not text_only:\n        new_conversations = []\n        for conversation in conversations:\n            for i in range(num_image):\n                image_tokens = f\"{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}\"\n                conversation = conversation.replace(\"<image>\", image_tokens, 1)\n            new_conversations.append(conversation)\n        conversations = new_conversations\n\n    # Tokenize conversations\n    input_ids = tokenizer(\n        conversations,\n        return_tensors=\"pt\",\n        padding=False if group_by_length or use_packed_ds else \"max_length\",\n        max_length=tokenizer.model_max_length,\n        truncation=True,\n    ).input_ids\n    targets = input_ids.clone()\n\n    for conversation, target in zip(conversations, targets):\n        total_len = int(target.ne(tokenizer.pad_token_id).sum())  # 浦语里面 pad_token_id = eos_token_id\n        cur_len = 1\n        target[:cur_len] = IGNORE_TOKEN_ID  # <s>\n        parts = conversation.split(conv.roles[1])  # [UNUSED_TOKEN_146]assistant\\n\n        info = parts[0] + conv.roles[1]\n        temp_len = len(tokenizer(info).input_ids) - 1  # 去除tokenizer的<s>\n        target[cur_len : cur_len + temp_len] = IGNORE_TOKEN_ID\n        cur_len = cur_len + temp_len\n\n        for index in range(1, len(parts) - 1):\n            info = parts[index]\n            part1, part2 = info.split(conv.roles[0])\n            temp_len = len(tokenizer(part1).input_ids) - 1\n            cur_len = cur_len + temp_len\n            part = conv.roles[0] + part2 + conv.roles[1]\n            temp_len = len(tokenizer(part).input_ids) - 1\n            target[cur_len : cur_len + temp_len] = IGNORE_TOKEN_ID\n            cur_len = cur_len + temp_len\n        last_info = parts[-1]\n        temp_len = len(tokenizer(last_info).input_ids) - 1\n        cur_len = cur_len + temp_len\n\n        target[cur_len:] = IGNORE_TOKEN_ID\n        if False:  # Inspect and check the correctness of masking\n            z = target.clone()\n            z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)\n            print(repr(tokenizer.decode(z)))\n\n        if cur_len < tokenizer.model_max_length:\n            if cur_len != total_len:\n                target[:] = IGNORE_TOKEN_ID\n                print(f\"WARNING: tokenization mismatch: {cur_len} vs. {total_len}. 
This dataset is {ds_name}.\")\n                sys.stdout.flush()\n\n    return dict(\n        input_ids=input_ids,\n        labels=targets,\n        attention_mask=input_ids.ne(tokenizer.pad_token_id),\n    )\n\n\ndef preprocess_internvl2_5(template_name, sources, tokenizer: transformers.PreTrainedTokenizer, num_image_token_list: list, text_only: bool = False, group_by_length: bool = False, use_packed_ds: bool = False, ds_name: str = None, num_image: int = 1) -> Dict:\n    assert len(sources) == 1, \"process only the first conversations\"\n    conversations = sources[0]\n\n    if conversations[0][\"from\"] == \"system\":\n        system_prompt = conversations[0][\"value\"]\n        conversations = conversations[1:]  # remove system prompt\n    else:\n        conv = get_conv_template(template_name)\n        system_prompt = conv.system_message\n        # system_prompt = None\n\n    if not text_only:\n        new_conversations = []\n        current_image_idx = 0\n        for conversation in conversations:\n            if conversation[\"from\"] == \"human\":\n                image_cnt = conversation[\"value\"].count(\"<image>\")\n                for i in range(image_cnt):\n                    if current_image_idx == num_image:\n                        break\n                    image_tokens = f\"{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[current_image_idx]}{IMG_END_TOKEN}\"\n                    conversation[\"value\"] = conversation[\"value\"].replace(\"<image>\", image_tokens, 1)\n                    current_image_idx += 1\n            new_conversations.append(conversation)\n        conversations = new_conversations\n        assert current_image_idx == num_image, f\"{current_image_idx} != {num_image}\"\n\n    batches, roles = [], []\n    if system_prompt is not None:\n        batches.append(f\"<|im_start|>system\\n{system_prompt}<|im_end|>\\n\")\n        roles.append(\"system\")\n    for conversation in conversations:\n        if conversation[\"from\"] == \"human\":\n            batches.append(f\"<|im_start|>user\\n{conversation['value']}<|im_end|>\\n\")\n            roles.append(\"human\")\n        elif conversation[\"from\"] == \"gpt\":\n            batches.append(f\"<|im_start|>assistant\\n{conversation['value']}<|im_end|>\\n\")\n            roles.append(\"gpt\")\n        else:\n            raise NotImplementedError\n\n    add_bos_token = getattr(tokenizer, \"add_bos_token\", False)\n    if add_bos_token:  # for InternLM series\n        batches[0] = tokenizer.bos_token + batches[0]\n\n    # Tokenize conversations\n    input_ids = tokenizer(\n        batches,\n        return_tensors=\"np\",\n        padding=False,\n        max_length=tokenizer.model_max_length,\n        truncation=False,\n    ).input_ids\n\n    if add_bos_token:  # for InternLM series\n        input_ids = [item[1:] for item in input_ids]\n\n    final_input_ids, final_targets = [], []\n    ignore_ids = tokenizer(\"<|im_start|>assistant\\n\", return_tensors=\"np\").input_ids[0]\n    ignore_len = ignore_ids.shape[0] - 1 if add_bos_token else ignore_ids.shape[0]\n    for role, input_id in zip(roles, input_ids):\n        final_input_ids.append(input_id)\n        if role == \"system\" or role == \"human\":\n            final_targets.append(np.full(input_id.shape, IGNORE_TOKEN_ID))  # ignore\n        elif role == \"gpt\":\n            target = input_id.copy()\n            target[:ignore_len] = IGNORE_TOKEN_ID  # ignore loss for `<|im_start|>assistant\\n`\n            target[-1:] = IGNORE_TOKEN_ID  # ignore loss 
for `\\n`\n            final_targets.append(target)\n        else:\n            raise NotImplementedError\n    input_ids = torch.tensor(np.concatenate(final_input_ids))[: tokenizer.model_max_length]\n    targets = torch.tensor(np.concatenate(final_targets))[: tokenizer.model_max_length]\n\n    padding = False if group_by_length or use_packed_ds else True\n    if padding:\n        current_length = input_ids.size(0)\n        padding_length = tokenizer.model_max_length - current_length\n        input_ids = F.pad(input_ids, (0, padding_length), value=tokenizer.pad_token_id)\n        targets = F.pad(targets, (0, padding_length), value=IGNORE_TOKEN_ID)\n\n    input_ids = input_ids.unsqueeze(0)\n    targets = targets.unsqueeze(0)\n\n    return dict(\n        input_ids=input_ids,\n        labels=targets,\n        attention_mask=input_ids.ne(tokenizer.pad_token_id),\n    )\n\n\ndef find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):\n    best_ratio_diff = float(\"inf\")\n    best_ratio = (1, 1)\n    area = width * height\n    for ratio in target_ratios:\n        target_aspect_ratio = ratio[0] / ratio[1]\n        ratio_diff = abs(aspect_ratio - target_aspect_ratio)\n        if ratio_diff < best_ratio_diff:\n            best_ratio_diff = ratio_diff\n            best_ratio = ratio\n        elif ratio_diff == best_ratio_diff:\n            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:\n                best_ratio = ratio\n    # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}')\n    return best_ratio\n\n\ndef dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):\n    orig_width, orig_height = image.size\n    aspect_ratio = orig_width / orig_height\n\n    # calculate the existing image aspect ratio\n    target_ratios = set((i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if i * j <= max_num and i * j >= min_num)\n    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])\n\n    # find the closest aspect ratio to the target\n    target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)\n\n    # calculate the target width and height\n    target_width = image_size * target_aspect_ratio[0]\n    target_height = image_size * target_aspect_ratio[1]\n    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]\n\n    # resize the image\n    resized_img = image.resize((target_width, target_height))\n    processed_images = []\n    for i in range(blocks):\n        box = ((i % (target_width // image_size)) * image_size, (i // (target_width // image_size)) * image_size, ((i % (target_width // image_size)) + 1) * image_size, ((i // (target_width // image_size)) + 1) * image_size)\n        # split the image\n        split_img = resized_img.crop(box)\n        processed_images.append(split_img)\n    assert len(processed_images) == blocks\n    if use_thumbnail and len(processed_images) != 1:\n        thumbnail_img = image.resize((image_size, image_size))\n        processed_images.append(thumbnail_img)\n    return processed_images\n\n\ndef preprocess_internvl2_5_siirl(sources, tokenizer: transformers.PreTrainedTokenizer, num_image_token_list: list, max_prompt_length: int, group_by_length: bool = False, use_packed_ds: bool = False, ds_name: str = None, num_image: int = 1, left_pad: bool = False) -> Dict:\n    if sources[0][\"role\"] == \"system\":\n        system_prompt = sources[0][\"content\"]\n        conversations = 
sources[1:]  # remove system prompt\n    else:\n        conv = get_conv_template(\"internvl2_5\")\n        system_prompt = conv.system_message\n        conversations = sources\n\n    for conversation in conversations:\n        current_image_idx = 0\n        if conversation[\"role\"] == \"user\":\n            image_cnt = conversation[\"content\"].count(\"<image>\")\n            for i in range(image_cnt):\n                if current_image_idx == num_image:\n                    break\n                image_tokens = f\"{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[current_image_idx]}{IMG_END_TOKEN}\"\n                conversation[\"content\"] = conversation[\"content\"].replace(\"<image>\", image_tokens, 1)\n                current_image_idx += 1\n            assert current_image_idx == num_image, f\"{current_image_idx} != {num_image}\"\n\n    batches, roles = [], []\n    if system_prompt is not None:\n        batches.append(f\"<|im_start|>system\\n{system_prompt}<|im_end|>\\n\")\n        roles.append(\"system\")\n\n    for conversation in conversations:\n        if conversation[\"role\"] == \"user\":\n            batches.append(f\"<|im_start|>user\\n{conversation['content']}<|im_end|>\\n\")\n            roles.append(\"human\")\n        elif conversation[\"role\"] == \"gpt\":\n            batches.append(f\"<|im_start|>assistant\\n{conversation['content']}<|im_end|>\\n\")\n            roles.append(\"gpt\")\n        else:\n            raise NotImplementedError\n\n    add_bos_token = getattr(tokenizer, \"add_bos_token\", False)\n    if add_bos_token:  # for InternLM series\n        batches[0] = tokenizer.bos_token + batches[0]\n\n    # Tokenize conversation\n    input_ids = tokenizer(\n        batches,\n        return_tensors=\"np\",\n        padding=False,\n        max_length=max_prompt_length,\n        truncation=False,\n    ).input_ids\n\n    if add_bos_token:  # for InternLM series\n        input_ids = [item[1:] for item in input_ids]\n\n    final_input_ids, final_targets = [], []\n    ignore_ids = tokenizer(\"<|im_start|>assistant\\n\", return_tensors=\"np\").input_ids[0]\n    ignore_len = ignore_ids.shape[0] - 1 if add_bos_token else ignore_ids.shape[0]\n    for role, input_id in zip(roles, input_ids):\n        final_input_ids.append(input_id)\n        if role == \"system\" or role == \"human\":\n            final_targets.append(np.full(input_id.shape, IGNORE_TOKEN_ID))  # ignore\n        elif role == \"gpt\":\n            target = input_id.copy()\n            target[:ignore_len] = IGNORE_TOKEN_ID  # ignore loss for `<|im_start|>assistant\\n`\n            target[-1:] = IGNORE_TOKEN_ID  # ignore loss for `\\n`\n            final_targets.append(target)\n        else:\n            raise NotImplementedError\n    input_ids = torch.tensor(np.concatenate(final_input_ids))[:max_prompt_length]\n    targets = torch.tensor(np.concatenate(final_targets))[:max_prompt_length]\n\n    padding = False if group_by_length or use_packed_ds else True\n    if padding:\n        current_length = input_ids.size(0)\n        padding_length = max_prompt_length - current_length\n        real_pad = (padding_length, 0) if left_pad else (0, padding_length)\n        input_ids = F.pad(input_ids, real_pad, value=tokenizer.pad_token_id)\n        targets = F.pad(targets, real_pad, value=IGNORE_TOKEN_ID)\n\n    input_ids = input_ids.unsqueeze(0)\n    targets = targets.unsqueeze(0)\n\n    return dict(\n        input_ids=input_ids,\n        labels=targets,\n        
attention_mask=input_ids.ne(tokenizer.pad_token_id),\n    )\n\n\ndef find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):\n    best_ratio_diff = float(\"inf\")\n    best_ratio = (1, 1)\n    area = width * height\n    for ratio in target_ratios:\n        target_aspect_ratio = ratio[0] / ratio[1]\n        ratio_diff = abs(aspect_ratio - target_aspect_ratio)\n        if ratio_diff < best_ratio_diff:\n            best_ratio_diff = ratio_diff\n            best_ratio = ratio\n        elif ratio_diff == best_ratio_diff:\n            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:\n                best_ratio = ratio\n    # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}')\n    return best_ratio\n\n\ndef dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):\n    orig_width, orig_height = image.size\n    aspect_ratio = orig_width / orig_height\n\n    # calculate the existing image aspect ratio\n    target_ratios = set((i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if i * j <= max_num and i * j >= min_num)\n    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])\n\n    # find the closest aspect ratio to the target\n    target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)\n\n    # calculate the target width and height\n    target_width = image_size * target_aspect_ratio[0]\n    target_height = image_size * target_aspect_ratio[1]\n    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]\n\n    # resize the image\n    resized_img = image.resize((target_width, target_height))\n    processed_images = []\n    for i in range(blocks):\n        box = ((i % (target_width // image_size)) * image_size, (i // (target_width // image_size)) * image_size, ((i % (target_width // image_size)) + 1) * image_size, ((i // (target_width // image_size)) + 1) * image_size)\n        # split the image\n        split_img = resized_img.crop(box)\n        processed_images.append(split_img)\n    assert len(processed_images) == blocks\n    if use_thumbnail and len(processed_images) != 1:\n        thumbnail_img = image.resize((image_size, image_size))\n        processed_images.append(thumbnail_img)\n    return processed_images\n"
  },
  {
    "path": "siirl/models/transformers/internvl_chat/__init__.py",
    "content": "from .configuration_intern_vit import InternVisionConfig\nfrom .configuration_internvl_chat import InternVLChatConfig\nfrom .modeling_intern_vit import InternVisionModel\nfrom .modeling_internvl_chat import InternVLChatModel\n\n__all__ = [\"InternVisionConfig\", \"InternVisionModel\", \"InternVLChatConfig\", \"InternVLChatModel\"]\n"
  },
  {
    "path": "siirl/models/transformers/internvl_chat/configuration_intern_vit.py",
    "content": "# --------------------------------------------------------\n# InternVL\n# Copyright (c) 2024 OpenGVLab\n# Licensed under The MIT License [see LICENSE for details]\n# --------------------------------------------------------\n\nimport os\nfrom typing import Union\n\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.utils import logging\n\nlogger = logging.get_logger(__name__)\n\n\nclass InternVisionConfig(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of a [`InternVisionModel`]. It is used to\n    instantiate a vision encoder according to the specified arguments, defining the model architecture.\n\n    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the\n    documentation from [`PretrainedConfig`] for more information.\n\n    Args:\n        num_channels (`int`, *optional*, defaults to 3):\n            Number of color channels in the input images (e.g., 3 for RGB).\n        patch_size (`int`, *optional*, defaults to 14):\n            The size (resolution) of each patch.\n        image_size (`int`, *optional*, defaults to 224):\n            The size (resolution) of each image.\n        qkv_bias (`bool`, *optional*, defaults to `False`):\n            Whether to add a bias to the queries and values in the self-attention layers.\n        hidden_size (`int`, *optional*, defaults to 3200):\n            Dimensionality of the encoder layers and the pooler layer.\n        num_attention_heads (`int`, *optional*, defaults to 25):\n            Number of attention heads for each attention layer in the Transformer encoder.\n        intermediate_size (`int`, *optional*, defaults to 12800):\n            Dimensionality of the \"intermediate\" (i.e., feed-forward) layer in the Transformer encoder.\n        qk_normalization (`bool`, *optional*, defaults to `True`):\n            Whether to normalize the queries and keys in the self-attention layers.\n        num_hidden_layers (`int`, *optional*, defaults to 48):\n            Number of hidden layers in the Transformer encoder.\n        use_flash_attn (`bool`, *optional*, defaults to `True`):\n            Whether to use flash attention mechanism.\n        hidden_act (`str` or `function`, *optional*, defaults to `\"gelu\"`):\n            The non-linear activation function (function or string) in the encoder and pooler. 
If string, `\"gelu\"`,\n            `\"relu\"`, `\"selu\"` and `\"gelu_new\"` ``\"gelu\"` are supported.\n        layer_norm_eps (`float`, *optional*, defaults to 1e-6):\n            The epsilon used by the layer normalization layers.\n        dropout (`float`, *optional*, defaults to 0.0):\n            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.\n        drop_path_rate (`float`, *optional*, defaults to 0.0):\n            Dropout rate for stochastic depth.\n        attention_dropout (`float`, *optional*, defaults to 0.0):\n            The dropout ratio for the attention probabilities.\n        initializer_range (`float`, *optional*, defaults to 0.02):\n            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.\n        initializer_factor (`float`, *optional*, defaults to 0.1):\n            A factor for layer scale.\n    \"\"\"\n\n    model_type = \"intern_vit_6b\"\n\n    def __init__(\n        self,\n        num_channels=3,\n        patch_size=14,\n        image_size=224,\n        qkv_bias=False,\n        hidden_size=3200,\n        num_attention_heads=25,\n        intermediate_size=12800,\n        qk_normalization=True,\n        num_hidden_layers=48,\n        use_flash_attn=True,\n        hidden_act=\"gelu\",\n        norm_type=\"rms_norm\",\n        layer_norm_eps=1e-6,\n        dropout=0.0,\n        drop_path_rate=0.0,\n        attention_dropout=0.0,\n        initializer_range=0.02,\n        initializer_factor=0.1,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.dropout = dropout\n        self.drop_path_rate = drop_path_rate\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.num_channels = num_channels\n        self.patch_size = patch_size\n        self.image_size = image_size\n        self.initializer_range = initializer_range\n        self.initializer_factor = initializer_factor\n        self.attention_dropout = attention_dropout\n        self.layer_norm_eps = layer_norm_eps\n        self.hidden_act = hidden_act\n        self.norm_type = norm_type\n        self.qkv_bias = qkv_bias\n        self.qk_normalization = qk_normalization\n        self.use_flash_attn = use_flash_attn\n\n    @classmethod\n    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> \"PretrainedConfig\":\n        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)\n\n        if \"vision_config\" in config_dict:\n            config_dict = config_dict[\"vision_config\"]\n\n        if \"model_type\" in config_dict and hasattr(cls, \"model_type\") and config_dict[\"model_type\"] != cls.model_type:\n            logger.warning(f\"You are using a model of type {config_dict['model_type']} to instantiate a model of type {cls.model_type}. This is not supported for all configurations of models and can yield errors.\")\n\n        return cls.from_dict(config_dict, **kwargs)\n"
  },
  {
    "path": "siirl/models/transformers/internvl_chat/configuration_internlm2.py",
    "content": "# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on transformers/src/transformers/models/llama/configuration_llama.py\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"InternLM2 model configuration\"\"\"\n\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.utils import logging\n\nlogger = logging.get_logger(__name__)\n\nINTERNLM2_PRETRAINED_CONFIG_ARCHIVE_MAP = {}\n\n\n# Modified from transformers.model.llama.configuration_llama.LlamaConfig\nclass InternLM2Config(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of a [`InternLM2Model`]. It is used to instantiate\n    an InternLM2 model according to the specified arguments, defining the model architecture. Instantiating a\n    configuration with the defaults will yield a similar configuration to that of the InternLM2-7B.\n\n    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the\n    documentation from [`PretrainedConfig`] for more information.\n\n\n    Args:\n        vocab_size (`int`, *optional*, defaults to 32000):\n            Vocabulary size of the InternLM2 model. Defines the number of different tokens that can be represented by the\n            `inputs_ids` passed when calling [`InternLM2Model`]\n        hidden_size (`int`, *optional*, defaults to 4096):\n            Dimension of the hidden representations.\n        intermediate_size (`int`, *optional*, defaults to 11008):\n            Dimension of the MLP representations.\n        num_hidden_layers (`int`, *optional*, defaults to 32):\n            Number of hidden layers in the Transformer encoder.\n        num_attention_heads (`int`, *optional*, defaults to 32):\n            Number of attention heads for each attention layer in the Transformer encoder.\n        num_key_value_heads (`int`, *optional*):\n            This is the number of key_value heads that should be used to implement Grouped Query Attention. If\n            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if\n            `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When\n            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed\n            by meanpooling all the original heads within that group. For more details checkout [this\n            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to\n            `num_attention_heads`.\n        hidden_act (`str` or `function`, *optional*, defaults to `\"silu\"`):\n            The non-linear activation function (function or string) in the decoder.\n        max_position_embeddings (`int`, *optional*, defaults to 2048):\n            The maximum sequence length that this model might ever be used with. 
Typically set this to something large\n            just in case (e.g., 512 or 1024 or 2048).\n        initializer_range (`float`, *optional*, defaults to 0.02):\n            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.\n        rms_norm_eps (`float`, *optional*, defaults to 1e-6):\n            The epsilon used by the rms normalization layers.\n        use_cache (`bool`, *optional*, defaults to `True`):\n            Whether or not the model should return the last key/values attentions (not used by all models). Only\n            relevant if `config.is_decoder=True`.\n        tie_word_embeddings (`bool`, *optional*, defaults to `False`):\n            Whether to tie weight embeddings\n        Example:\n\n    \"\"\"\n\n    model_type = \"internlm2\"\n    _auto_class = \"AutoConfig\"\n\n    def __init__(  # pylint: disable=W0102\n        self,\n        vocab_size=103168,\n        hidden_size=4096,\n        intermediate_size=11008,\n        num_hidden_layers=32,\n        num_attention_heads=32,\n        num_key_value_heads=None,\n        hidden_act=\"silu\",\n        max_position_embeddings=2048,\n        initializer_range=0.02,\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        pad_token_id=0,\n        bos_token_id=1,\n        eos_token_id=2,\n        tie_word_embeddings=False,\n        bias=True,\n        rope_theta=10000,\n        rope_scaling=None,\n        attn_implementation=\"eager\",\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.bias = bias\n\n        if num_key_value_heads is None:\n            num_key_value_heads = num_attention_heads\n        self.num_key_value_heads = num_key_value_heads\n\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n        self._rope_scaling_validation()\n\n        self.attn_implementation = attn_implementation\n        if self.attn_implementation is None:\n            self.attn_implementation = \"eager\"\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n\n    def _rope_scaling_validation(self):\n        \"\"\"\n        Validate the `rope_scaling` configuration.\n        \"\"\"\n        if self.rope_scaling is None:\n            return\n\n        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:\n            raise ValueError(f\"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, got {self.rope_scaling}\")\n        rope_scaling_type = self.rope_scaling.get(\"type\", None)\n        rope_scaling_factor = self.rope_scaling.get(\"factor\", None)\n        if rope_scaling_type is None or rope_scaling_type not in [\"linear\", \"dynamic\"]:\n            raise ValueError(f\"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}\")\n        if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor < 1.0:\n            raise ValueError(f\"`rope_scaling`'s factor field must be a float >= 1, got {rope_scaling_factor}\")\n"
  },
  {
    "path": "siirl/models/transformers/internvl_chat/configuration_internvl_chat.py",
    "content": "# --------------------------------------------------------\n# InternVL\n# Copyright (c) 2024 OpenGVLab\n# Licensed under The MIT License [see LICENSE for details]\n# --------------------------------------------------------\n\nimport copy\n\nfrom transformers import AutoConfig, LlamaConfig\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.utils import logging\n\nfrom .configuration_intern_vit import InternVisionConfig\nfrom .configuration_internlm2 import InternLM2Config\n\nlogger = logging.get_logger(__name__)\n\n\nclass InternVLChatConfig(PretrainedConfig):\n    model_type = \"internvl_chat\"\n    is_composition = True\n\n    def __init__(self, vision_config=None, llm_config=None, use_backbone_lora=0, use_llm_lora=0, select_layer=-1, force_image_size=None, downsample_ratio=0.5, template=None, dynamic_image_size=False, use_thumbnail=False, ps_version=\"v1\", min_dynamic_patch=1, max_dynamic_patch=6, **kwargs):\n        super().__init__(**kwargs)\n\n        if vision_config is None:\n            vision_config = {\"architectures\": [\"InternVisionModel\"]}\n            logger.info(\"vision_config is None. Initializing the InternVisionConfig with default values.\")\n\n        if llm_config is None:\n            llm_config = {\"architectures\": [\"InternLM2ForCausalLM\"]}\n            logger.info(\"llm_config is None. Initializing the LlamaConfig config with default values (`LlamaConfig`).\")\n\n        self.vision_config = InternVisionConfig(**vision_config)\n        if llm_config.get(\"architectures\")[0] == \"LlamaForCausalLM\":\n            self.llm_config = LlamaConfig(**llm_config)\n        elif llm_config.get(\"architectures\")[0] == \"InternLM2ForCausalLM\":\n            self.llm_config = InternLM2Config(**llm_config)\n        else:\n            raise ValueError(\"Unsupported architecture: {}\".format(llm_config.get(\"architectures\")[0]))\n        self.use_backbone_lora = use_backbone_lora\n        self.use_llm_lora = use_llm_lora\n        self.select_layer = select_layer\n        self.force_image_size = force_image_size\n        self.downsample_ratio = downsample_ratio\n        self.template = template\n        self.dynamic_image_size = dynamic_image_size\n        self.use_thumbnail = use_thumbnail\n        self.ps_version = ps_version  # pixel shuffle version\n        self.min_dynamic_patch = min_dynamic_patch\n        self.max_dynamic_patch = max_dynamic_patch\n\n        logger.info(f\"vision_select_layer: {self.select_layer}\")\n        logger.info(f\"ps_version: {self.ps_version}\")\n        logger.info(f\"min_dynamic_patch: {self.min_dynamic_patch}\")\n        logger.info(f\"max_dynamic_patch: {self.max_dynamic_patch}\")\n\n    def to_dict(self):\n        \"\"\"\n        Serializes this instance to a Python dictionary. 
Override the default [`~PretrainedConfig.to_dict`].\n\n        Returns:\n            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,\n        \"\"\"\n        output = copy.deepcopy(self.__dict__)\n        output[\"vision_config\"] = self.vision_config.to_dict()\n        output[\"llm_config\"] = self.llm_config.to_dict()\n        output[\"model_type\"] = self.__class__.model_type\n        output[\"use_backbone_lora\"] = self.use_backbone_lora\n        output[\"use_llm_lora\"] = self.use_llm_lora\n        output[\"select_layer\"] = self.select_layer\n        output[\"force_image_size\"] = self.force_image_size\n        output[\"downsample_ratio\"] = self.downsample_ratio\n        output[\"template\"] = self.template\n        output[\"dynamic_image_size\"] = self.dynamic_image_size\n        output[\"use_thumbnail\"] = self.use_thumbnail\n        output[\"ps_version\"] = self.ps_version\n        output[\"min_dynamic_patch\"] = self.min_dynamic_patch\n        output[\"max_dynamic_patch\"] = self.max_dynamic_patch\n\n        return output\n"
  },
  {
    "path": "siirl/models/transformers/internvl_chat/modeling_intern_vit.py",
    "content": "# --------------------------------------------------------\n# InternVL\n# Copyright (c) 2024 OpenGVLab\n# Licensed under The MIT License [see LICENSE for details]\n# --------------------------------------------------------\n\nfrom typing import Optional, Tuple, Union\n\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nfrom einops import rearrange\nfrom timm.models.layers import DropPath\nfrom torch import nn\nfrom transformers.activations import ACT2FN\nfrom transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling\nfrom transformers.modeling_utils import PreTrainedModel\nfrom transformers.utils import logging\n\nfrom .configuration_intern_vit import InternVisionConfig\n\ntry:\n    from flash_attn.bert_padding import pad_input, unpad_input\n    from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func\n\n    has_flash_attn = True\nexcept:\n    print(\"FlashAttention2 is not installed.\")\n    has_flash_attn = False\n\nlogger = logging.get_logger(__name__)\n\n\nclass FlashAttention(nn.Module):\n    \"\"\"Implement the scaled dot product attention with softmax.\n    Arguments\n    ---------\n        softmax_scale: The temperature to use for the softmax attention.\n                      (default: 1/sqrt(d_keys) where d_keys is computed at\n                      runtime)\n        attention_dropout: The dropout rate to apply to the attention\n                           (default: 0.0)\n    \"\"\"\n\n    def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):\n        super().__init__()\n        self.softmax_scale = softmax_scale\n        self.dropout_p = attention_dropout\n\n    def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None, max_s=None, need_weights=False):\n        \"\"\"Implements the multihead softmax attention.\n        Arguments\n        ---------\n            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None\n                if unpadded: (nnz, 3, h, d)\n            key_padding_mask: a bool tensor of shape (B, S)\n        \"\"\"\n        assert not need_weights\n        assert qkv.dtype in [torch.float16, torch.bfloat16]\n        assert qkv.is_cuda\n\n        if cu_seqlens is None:\n            batch_size = qkv.shape[0]\n            seqlen = qkv.shape[1]\n            if key_padding_mask is None:\n                qkv = rearrange(qkv, \"b s ... -> (b s) ...\")\n                max_s = seqlen\n                cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=qkv.device)\n                output = flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0, softmax_scale=self.softmax_scale, causal=causal)\n                output = rearrange(output, \"(b s) ... 
-> b s ...\", b=batch_size)\n            else:\n                nheads = qkv.shape[-2]\n                x = rearrange(qkv, \"b s three h d -> b s (three h d)\")\n                x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)\n                x_unpad = rearrange(x_unpad, \"nnz (three h d) -> nnz three h d\", three=3, h=nheads)\n                output_unpad = flash_attn_varlen_qkvpacked_func(x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0, softmax_scale=self.softmax_scale, causal=causal)\n                output = rearrange(pad_input(rearrange(output_unpad, \"nnz h d -> nnz (h d)\"), indices, batch_size, seqlen), \"b s (h d) -> b s h d\", h=nheads)\n        else:\n            assert max_s is not None\n            output = flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0, softmax_scale=self.softmax_scale, causal=causal)\n\n        return output, None\n\n\nclass InternRMSNorm(nn.Module):\n    def __init__(self, hidden_size, eps=1e-6):\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n        self.variance_epsilon = eps\n\n    def forward(self, hidden_states):\n        input_dtype = hidden_states.dtype\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        return self.weight * hidden_states.to(input_dtype)\n\n\ntry:\n    from apex.normalization import FusedRMSNorm\n\n    InternRMSNorm = FusedRMSNorm  # noqa\n\n    logger.info(\"Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm\")\nexcept ImportError:\n    # using the normal InternRMSNorm\n    pass\nexcept Exception:\n    logger.warning(\"discovered apex but it failed to load, falling back to InternRMSNorm\")\n    pass\n\n\nNORM2FN = {\n    \"rms_norm\": InternRMSNorm,\n    \"layer_norm\": nn.LayerNorm,\n}\n\n\nclass InternVisionEmbeddings(nn.Module):\n    def __init__(self, config: InternVisionConfig):\n        super().__init__()\n        self.config = config\n        self.embed_dim = config.hidden_size\n        self.image_size = config.image_size\n        self.patch_size = config.patch_size\n\n        self.class_embedding = nn.Parameter(\n            torch.randn(1, 1, self.embed_dim),\n        )\n\n        self.patch_embedding = nn.Conv2d(in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size)\n\n        self.num_patches = (self.image_size // self.patch_size) ** 2\n        self.num_positions = self.num_patches + 1\n\n        self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))\n\n    def _get_pos_embed(self, pos_embed, H, W):\n        target_dtype = pos_embed.dtype\n        pos_embed = pos_embed.float().reshape(1, self.image_size // self.patch_size, self.image_size // self.patch_size, -1).permute(0, 3, 1, 2)\n        pos_embed = F.interpolate(pos_embed, size=(H, W), mode=\"bicubic\", align_corners=False).reshape(1, -1, H * W).permute(0, 2, 1).to(target_dtype)\n        return pos_embed\n\n    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:\n        target_dtype = self.patch_embedding.weight.dtype\n        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, channel, width, height]\n        batch_size, _, height, width = patch_embeds.shape\n        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)\n        
class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)\n        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)\n        position_embedding = torch.cat([self.position_embedding[:, :1, :], self._get_pos_embed(self.position_embedding[:, 1:, :], height, width)], dim=1)\n        embeddings = embeddings + position_embedding.to(target_dtype)\n        return embeddings\n\n\nclass InternAttention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, config: InternVisionConfig):\n        super().__init__()\n        self.config = config\n        self.embed_dim = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.use_flash_attn = config.use_flash_attn and has_flash_attn\n        if config.use_flash_attn and not has_flash_attn:\n            print(\"Warning: Flash Attention is not available, use_flash_attn is set to False.\")\n        self.head_dim = self.embed_dim // self.num_heads\n        if self.head_dim * self.num_heads != self.embed_dim:\n            raise ValueError(f\"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads}).\")\n\n        self.scale = self.head_dim**-0.5\n        self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)\n        self.attn_drop = nn.Dropout(config.attention_dropout)\n        self.proj_drop = nn.Dropout(config.dropout)\n\n        self.qk_normalization = config.qk_normalization\n\n        if self.qk_normalization:\n            self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)\n            self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)\n\n        if self.use_flash_attn:\n            self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)\n        self.proj = nn.Linear(self.embed_dim, self.embed_dim)\n\n    def _naive_attn(self, x):\n        B, N, C = x.shape\n        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)\n\n        if self.qk_normalization:\n            B_, H_, N_, D_ = q.shape\n            q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)\n            k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)\n\n        attn = (q * self.scale) @ k.transpose(-2, -1)\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n    def _flash_attn(self, x, key_padding_mask=None, need_weights=False):\n        qkv = self.qkv(x)\n        qkv = rearrange(qkv, \"b s (three h d) -> b s three h d\", three=3, h=self.num_heads)\n\n        if self.qk_normalization:\n            q, k, v = qkv.unbind(2)\n            q = self.q_norm(q.flatten(-2, -1)).view(q.shape)\n            k = self.k_norm(k.flatten(-2, -1)).view(k.shape)\n            qkv = torch.stack([q, k, v], dim=2)\n\n        context, _ = self.inner_attn(qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False)\n        outs = self.proj(rearrange(context, \"b s h d -> b s (h d)\"))\n        outs = self.proj_drop(outs)\n        return outs\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        x = self._naive_attn(hidden_states) if not 
self.use_flash_attn else self._flash_attn(hidden_states)\n        return x\n\n\nclass InternMLP(nn.Module):\n    def __init__(self, config: InternVisionConfig):\n        super().__init__()\n        self.config = config\n        self.act = ACT2FN[config.hidden_act]\n        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)\n        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        hidden_states = self.fc1(hidden_states)\n        hidden_states = self.act(hidden_states)\n        hidden_states = self.fc2(hidden_states)\n        return hidden_states\n\n\nclass InternVisionEncoderLayer(nn.Module):\n    def __init__(self, config: InternVisionConfig, drop_path_rate: float):\n        super().__init__()\n        self.embed_dim = config.hidden_size\n        self.intermediate_size = config.intermediate_size\n        self.norm_type = config.norm_type\n\n        self.attn = InternAttention(config)\n        self.mlp = InternMLP(config)\n        self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)\n        self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)\n\n        self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))\n        self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))\n        self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()\n        self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Tuple[torch.FloatTensor]]]:\n        \"\"\"\n        Args:\n            hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`\n        \"\"\"\n        hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1)\n\n        hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states).to(hidden_states.dtype)) * self.ls2)\n\n        return hidden_states\n\n\nclass InternVisionEncoder(nn.Module):\n    \"\"\"\n    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a\n    [`InternEncoderLayer`].\n\n    Args:\n        config (`InternConfig`):\n            The corresponding vision configuration for the `InternEncoder`.\n    \"\"\"\n\n    def __init__(self, config: InternVisionConfig):\n        super().__init__()\n        self.config = config\n        # stochastic depth decay rule\n        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]\n        self.layers = nn.ModuleList([InternVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)])\n        self.gradient_checkpointing = True\n\n    def forward(\n        self,\n        inputs_embeds,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, BaseModelOutput]:\n        r\"\"\"\n        Args:\n            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):\n                Embedded representation of the inputs. 
Should be float, not int tokens.\n            output_hidden_states (`bool`, *optional*):\n                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors\n                for more detail.\n            return_dict (`bool`, *optional*):\n                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n        \"\"\"\n        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        encoder_states = () if output_hidden_states else None\n        hidden_states = inputs_embeds\n\n        for idx, encoder_layer in enumerate(self.layers):\n            if output_hidden_states:\n                encoder_states = encoder_states + (hidden_states,)\n            if self.gradient_checkpointing and self.training:\n                layer_outputs = torch.utils.checkpoint.checkpoint(encoder_layer, hidden_states)\n            else:\n                layer_outputs = encoder_layer(\n                    hidden_states,\n                )\n            hidden_states = layer_outputs\n\n        if output_hidden_states:\n            encoder_states = encoder_states + (hidden_states,)\n\n        if not return_dict:\n            return tuple(v for v in [hidden_states, encoder_states] if v is not None)\n        return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_states)\n\n\nclass InternVisionModel(PreTrainedModel):\n    main_input_name = \"pixel_values\"\n    _supports_flash_attn_2 = True\n    config_class = InternVisionConfig\n    _no_split_modules = [\"InternVisionEncoderLayer\"]\n\n    def __init__(self, config: InternVisionConfig):\n        super().__init__(config)\n        self.config = config\n\n        self.embeddings = InternVisionEmbeddings(config)\n        self.encoder = InternVisionEncoder(config)\n\n    def resize_pos_embeddings(self, old_size, new_size, patch_size):\n        pos_emb = self.embeddings.position_embedding\n        _, num_positions, embed_dim = pos_emb.shape\n        cls_emb = pos_emb[:, :1, :]\n        pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2)\n        pos_emb = F.interpolate(pos_emb.float(), size=new_size // patch_size, mode=\"bicubic\", align_corners=False)\n        pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)\n        pos_emb = torch.cat([cls_emb, pos_emb], dim=1)\n        self.embeddings.position_embedding = nn.Parameter(pos_emb)\n        self.embeddings.image_size = new_size\n        logger.info(\"Resized position embeddings from {} to {}\".format(old_size, new_size))\n\n    def get_input_embeddings(self):\n        return self.embeddings\n\n    def forward(\n        self,\n        pixel_values: Optional[torch.FloatTensor] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        pixel_embeds: Optional[torch.FloatTensor] = None,\n    ) -> Union[Tuple, BaseModelOutputWithPooling]:\n        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        if pixel_values is None and pixel_embeds is None:\n            raise ValueError(\"You have to specify pixel_values or pixel_embeds\")\n\n        if pixel_embeds 
is not None:\n            hidden_states = pixel_embeds\n        else:\n            if len(pixel_values.shape) == 4:\n                hidden_states = self.embeddings(pixel_values)\n            else:\n                raise ValueError(f\"wrong pixel_values size: {pixel_values.shape}\")\n        encoder_outputs = self.encoder(\n            inputs_embeds=hidden_states,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        last_hidden_state = encoder_outputs.last_hidden_state\n        pooled_output = last_hidden_state[:, 0, :]\n\n        if not return_dict:\n            return (last_hidden_state, pooled_output) + encoder_outputs[1:]\n\n        return BaseModelOutputWithPooling(\n            last_hidden_state=last_hidden_state,\n            pooler_output=pooled_output,\n            hidden_states=encoder_outputs.hidden_states,\n            attentions=encoder_outputs.attentions,\n        )\n"
  },
  {
    "path": "siirl/models/transformers/internvl_chat/modeling_internlm2.py",
    "content": "# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on transformers/src/transformers/models/llama/modeling_llama.py\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"PyTorch InternLM2 model.\"\"\"\n\nimport math\nimport queue\nimport threading\nimport warnings\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nfrom einops import rearrange\nfrom torch import nn\nfrom torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss\nfrom transformers.activations import ACT2FN\nfrom transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast\nfrom transformers.modeling_utils import PreTrainedModel\nfrom transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings\n\ntry:\n    from transformers.generation.streamers import BaseStreamer\nexcept:  # noqa # pylint: disable=bare-except\n    BaseStreamer = None\n\nfrom .configuration_internlm2 import InternLM2Config\n\nlogger = logging.get_logger(__name__)\n\n_CONFIG_FOR_DOC = \"InternLM2Config\"\n\nflash_attn_func, flash_attn_varlen_func = None, None\npad_input, index_first_axis, unpad_input = None, None, None\ntry:\n    from flash_attn import flash_attn_func as _flash_attn_func\n    from flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func\n    from flash_attn.bert_padding import index_first_axis as _index_first_axis\n    from flash_attn.bert_padding import pad_input as _pad_input\n    from flash_attn.bert_padding import unpad_input as _unpad_input\n\n    flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func\n    pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input\n    has_flash_attn = True\nexcept:\n    has_flash_attn = False\n\n\ndef _import_flash_attn():\n    global flash_attn_func, flash_attn_varlen_func\n    global pad_input, index_first_axis, unpad_input\n    try:\n        from flash_attn import flash_attn_func as _flash_attn_func\n        from flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func\n        from flash_attn.bert_padding import index_first_axis as _index_first_axis\n        from flash_attn.bert_padding import pad_input as _pad_input\n        from flash_attn.bert_padding import unpad_input as _unpad_input\n\n        flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func\n        pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input\n    except ImportError:\n        raise ImportError(\"flash_attn is not installed.\")\n\n\n# Copied from transformers.models.llama.modeling_llama._get_unpad_data\ndef _get_unpad_data(attention_mask):\n    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)\n    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()\n    max_seqlen_in_batch = 
seqlens_in_batch.max().item()\n    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))\n    return (\n        indices,\n        cu_seqlens,\n        max_seqlen_in_batch,\n    )\n\n\n# Copied from transformers.models.bart.modeling_bart._make_causal_mask\ndef _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0):\n    \"\"\"\n    Make causal mask used for bi-directional self-attention.\n    \"\"\"\n    bsz, tgt_len = input_ids_shape\n    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)\n    mask_cond = torch.arange(mask.size(-1), device=device)\n    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)\n    mask = mask.to(dtype)\n\n    if past_key_values_length > 0:\n        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)\n    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)\n\n\n# Copied from transformers.models.bart.modeling_bart._expand_mask\ndef _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):\n    \"\"\"\n    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.\n    \"\"\"\n    bsz, src_len = mask.size()\n    tgt_len = tgt_len if tgt_len is not None else src_len\n\n    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)\n\n    inverted_mask = 1.0 - expanded_mask\n\n    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)\n\n\n# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->InternLM2\nclass InternLM2RMSNorm(nn.Module):\n    def __init__(self, hidden_size, eps=1e-6):\n        \"\"\"\n        InternLM2RMSNorm is equivalent to T5LayerNorm\n        \"\"\"\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n        self.variance_epsilon = eps\n\n    def forward(self, hidden_states):\n        input_dtype = hidden_states.dtype\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        return self.weight * hidden_states.to(input_dtype)\n\n\n# Copied from transformers.model.llama.modeling_llama.LlamaRotaryEmbedding with Llama->InternLM2\nclass InternLM2RotaryEmbedding(nn.Module):\n    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):\n        super().__init__()\n\n        self.dim = dim\n        self.max_position_embeddings = max_position_embeddings\n        self.base = base\n        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n\n        # Build here to make `torch.jit.trace` work.\n        self._set_cos_sin_cache(seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype())\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n        t = torch.arange(self.max_seq_len_cached, device=device).to(dtype=self.inv_freq.dtype)\n\n        freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), 
dim=-1)\n        self.register_buffer(\"cos_cached\", emb.cos().to(dtype), persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin().to(dtype), persistent=False)\n\n    def forward(self, x, seq_len=None):\n        # x: [bs, num_attention_heads, seq_len, head_size]\n        if seq_len > self.max_seq_len_cached:\n            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=torch.float32)\n\n        return (\n            self.cos_cached[:seq_len].to(dtype=x.dtype),\n            self.sin_cached[:seq_len].to(dtype=x.dtype),\n        )\n\n\n# Copied from transformers.model.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->InternLM2\nclass InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding):\n    \"\"\"InternLM2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev\"\"\"\n\n    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):\n        self.scaling_factor = scaling_factor\n        super().__init__(dim, max_position_embeddings, base, device)\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n        t = torch.arange(self.max_seq_len_cached, device=device).to(dtype=self.inv_freq.dtype)\n        t = t / self.scaling_factor\n\n        freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\"cos_cached\", emb.cos().to(dtype), persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin().to(dtype), persistent=False)\n\n\n# Copied from transformers.model.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->InternLM2\nclass InternLM2DynamicNTKScalingRotaryEmbedding(InternLM2RotaryEmbedding):\n    \"\"\"InternLM2RotaryEmbedding extended with Dynamic NTK scaling.\n    Credits to the Reddit users /u/bloc97 and /u/emozilla.\n    \"\"\"\n\n    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):\n        self.scaling_factor = scaling_factor\n        super().__init__(dim, max_position_embeddings, base, device)\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n\n        if seq_len > self.max_position_embeddings:\n            base = self.base * ((self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)) ** (self.dim / (self.dim - 2))\n            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))\n            self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n\n        t = torch.arange(self.max_seq_len_cached, device=device).to(dtype=self.inv_freq.dtype)\n\n        freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\"cos_cached\", emb.cos().to(dtype), persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin().to(dtype), persistent=False)\n\n\n# Copied from transformers.model.llama.modeling_llama.rotate_half\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\n# Copied from 
transformers.model.llama.modeling_llama.apply_rotary_pos_emb\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):\n    \"\"\"Applies Rotary Position Embedding to the query and key tensors.\"\"\"\n    cos = cos[position_ids].unsqueeze(unsqueeze_dim)\n    sin = sin[position_ids].unsqueeze(unsqueeze_dim)\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed, k_embed\n\n\nclass InternLM2MLP(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.config = config\n        self.hidden_size = config.hidden_size\n        self.intermediate_size = config.intermediate_size\n        self.w1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.w3 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.w2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, x):\n        down_proj = self.w2(self.act_fn(self.w1(x)) * self.w3(x))\n\n        return down_proj\n\n\n# Copied from transformers.model.llama.modeling_llama.repeat_kv\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\n# Modified from transformers.model.llama.modeling_llama.LlamaAttention\nclass InternLM2Attention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, config: InternLM2Config):\n        super().__init__()\n        self.config = config\n        self.hidden_size = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_dim = self.hidden_size // self.num_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.num_key_value_groups = self.num_heads // self.num_key_value_heads\n        self.max_position_embeddings = config.max_position_embeddings\n        self.is_causal = True\n\n        if (self.head_dim * self.num_heads) != self.hidden_size:\n            raise ValueError(f\"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`: {self.num_heads}).\")\n\n        self.wqkv = nn.Linear(\n            self.hidden_size,\n            (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim,\n            bias=config.bias,\n        )\n\n        self.wo = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)\n        self._init_rope()\n\n    def _init_rope(self):\n        if self.config.rope_scaling is None:\n            self.rotary_emb = InternLM2RotaryEmbedding(\n                self.head_dim,\n                max_position_embeddings=self.max_position_embeddings,\n                base=self.config.rope_theta,\n            )\n        else:\n            scaling_type = self.config.rope_scaling[\"type\"]\n            scaling_factor = self.config.rope_scaling[\"factor\"]\n            if scaling_type == \"dynamic\":\n                
self.rotary_emb = InternLM2DynamicNTKScalingRotaryEmbedding(\n                    self.head_dim,\n                    max_position_embeddings=self.max_position_embeddings,\n                    base=self.config.rope_theta,\n                    scaling_factor=scaling_factor,\n                )\n            elif scaling_type == \"linear\":\n                self.rotary_emb = InternLM2LinearScalingRotaryEmbedding(\n                    self.head_dim,\n                    max_position_embeddings=self.max_position_embeddings,\n                    base=self.config.rope_theta,\n                    scaling_factor=scaling_factor,\n                )\n            else:\n                raise ValueError(\"Currently we only support rotary embedding's type being 'dynamic' or 'linear'.\")\n        return self.rotary_emb\n\n    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):\n        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        **kwargs,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        if \"padding_mask\" in kwargs:\n            warnings.warn(\"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`\")\n\n        bsz, q_len, _ = hidden_states.size()\n\n        qkv_states = self.wqkv(hidden_states)\n\n        qkv_states = rearrange(\n            qkv_states,\n            \"b q (h gs d) -> b q h gs d\",\n            gs=2 + self.num_key_value_groups,\n            d=self.head_dim,\n        )\n\n        query_states = qkv_states[..., : self.num_key_value_groups, :]\n        query_states = rearrange(query_states, \"b q h gs d -> b q (h gs) d\")\n        key_states = qkv_states[..., -2, :]\n        value_states = qkv_states[..., -1, :]\n\n        query_states = query_states.transpose(1, 2)\n        key_states = key_states.transpose(1, 2)\n        value_states = value_states.transpose(1, 2)\n\n        kv_seq_len = key_states.shape[-2]\n        if past_key_value is not None:\n            kv_seq_len += past_key_value[0].shape[-2]\n        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)\n        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)\n\n        if past_key_value is not None:\n            # reuse k, v, self_attention\n            key_states = torch.cat([past_key_value[0], key_states], dim=2)\n            value_states = torch.cat([past_key_value[1], value_states], dim=2)\n\n        past_key_value = (key_states, value_states) if use_cache else None\n\n        key_states = repeat_kv(key_states, self.num_key_value_groups)\n        value_states = repeat_kv(value_states, self.num_key_value_groups)\n\n        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)\n\n        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):\n            raise ValueError(f\"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is {attn_weights.size()}\")\n\n        if attention_mask is not None:\n            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):\n                raise 
ValueError(f\"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}\")\n            attn_weights = attn_weights + attention_mask\n\n        # upcast attention to fp32\n        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)\n        attn_output = torch.matmul(attn_weights, value_states)\n\n        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):\n            raise ValueError(f\"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is {attn_output.size()}\")\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)\n\n        attn_output = self.wo(attn_output)\n\n        if not output_attentions:\n            attn_weights = None\n\n        return attn_output, attn_weights, past_key_value\n\n\n# Modified from transformers.model.llama.modeling_llama.InternLM2FlashAttention2\nclass InternLM2FlashAttention2(InternLM2Attention):\n    \"\"\"\n    InternLM2 flash attention module. This module inherits from `InternLM2Attention` as the weights of the module stays\n    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of\n    flash attention and deal with padding tokens in case the input contains any of them.\n    \"\"\"\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.LongTensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        **kwargs,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        # InternLM2FlashAttention2 attention does not support output_attentions\n        if \"padding_mask\" in kwargs:\n            warnings.warn(\"Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`\")\n\n            # overwrite attention_mask with padding_mask\n            attention_mask = kwargs.pop(\"padding_mask\")\n\n        output_attentions = False\n\n        bsz, q_len, _ = hidden_states.size()\n\n        qkv_states = self.wqkv(hidden_states)\n\n        qkv_states = rearrange(\n            qkv_states,\n            \"b q (h gs d) -> b q h gs d\",\n            gs=2 + self.num_key_value_groups,\n            d=self.head_dim,\n        )\n\n        query_states = qkv_states[..., : self.num_key_value_groups, :]\n        query_states = rearrange(query_states, \"b q h gs d -> b q (h gs) d\")\n        key_states = qkv_states[..., -2, :]\n        value_states = qkv_states[..., -1, :]\n\n        query_states = query_states.transpose(1, 2)\n        key_states = key_states.transpose(1, 2)\n        value_states = value_states.transpose(1, 2)\n\n        kv_seq_len = key_states.shape[-2]\n        if past_key_value is not None:\n            kv_seq_len += past_key_value[0].shape[-2]\n\n        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)\n\n        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)\n\n        if past_key_value is not None:\n            # reuse k, v, self_attention\n            key_states = torch.cat([past_key_value[0], key_states], dim=2)\n            value_states = torch.cat([past_key_value[1], value_states], dim=2)\n\n        past_key_value = (key_states, value_states) if use_cache else None\n\n        query_states = query_states.transpose(1, 2)\n        key_states = key_states.transpose(1, 2)\n        value_states = value_states.transpose(1, 2)\n\n        attn_output = self._flash_attention_forward(query_states, key_states, value_states, attention_mask, q_len)\n        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()\n        attn_output = self.wo(attn_output)\n\n        if not output_attentions:\n            attn_weights = None\n\n        return attn_output, attn_weights, past_key_value\n\n    def _flash_attention_forward(self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None):\n        \"\"\"\n        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token\n        first unpad the input, then computes the attention scores and pad the final attention scores.\n\n        Args:\n            query_states (`torch.Tensor`):\n                Input query states to be passed to Flash Attention API\n            key_states (`torch.Tensor`):\n                Input key states to be passed to Flash Attention API\n            value_states (`torch.Tensor`):\n                Input value states to be passed to Flash Attention API\n            attention_mask (`torch.Tensor`):\n                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the\n                position of padding tokens and 1 for the position of non-padding tokens.\n            dropout (`int`, *optional*):\n                Attention dropout\n            softmax_scale (`float`, *optional*):\n                The scaling of QK^T before applying softmax. 
Default to 1 / sqrt(head_dim)\n        \"\"\"\n        # Contains at least one padding token in the sequence\n        causal = self.is_causal and query_length != 1\n        if attention_mask is not None:\n            batch_size = query_states.shape[0]\n            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._unpad_input(query_states, key_states, value_states, attention_mask, query_length)\n\n            cu_seqlens_q, cu_seqlens_k = cu_seq_lens\n            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens\n\n            attn_output_unpad = flash_attn_varlen_func(\n                query_states,\n                key_states,\n                value_states,\n                cu_seqlens_q=cu_seqlens_q,\n                cu_seqlens_k=cu_seqlens_k,\n                max_seqlen_q=max_seqlen_in_batch_q,\n                max_seqlen_k=max_seqlen_in_batch_k,\n                dropout_p=dropout,\n                softmax_scale=softmax_scale,\n                causal=causal,\n            )\n\n            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)\n        else:\n            attn_output = flash_attn_func(query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal)\n\n        return attn_output\n\n    def _unpad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):\n        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)\n        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape\n\n        key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)\n        value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)\n\n        if query_length == kv_seq_len:\n            query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k)\n            cu_seqlens_q = cu_seqlens_k\n            max_seqlen_in_batch_q = max_seqlen_in_batch_k\n            indices_q = indices_k\n        elif query_length == 1:\n            max_seqlen_in_batch_q = 1\n            cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=query_layer.device)  # There is a memcpy here, that is very bad.\n            indices_q = cu_seqlens_q[:-1]\n            query_layer = query_layer.squeeze(1)\n        else:\n            # The -q_len: slice assumes left padding.\n            attention_mask = attention_mask[:, -query_length:]\n            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)\n\n        return (\n            query_layer,\n            key_layer,\n            value_layer,\n            indices_q.to(torch.int64),\n            (cu_seqlens_q, cu_seqlens_k),\n            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),\n        )\n\n\nINTERNLM2_ATTENTION_CLASSES = {\n    \"eager\": InternLM2Attention,\n    \"flash_attention_2\": InternLM2FlashAttention2,\n}\n\n\n# Modified from transformers.model.llama.modeling_llama.LlamaDecoderLayer\nclass InternLM2DecoderLayer(nn.Module):\n    def __init__(self, config: InternLM2Config):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n\n        self.attention = INTERNLM2_ATTENTION_CLASSES[config.attn_implementation](config=config)\n\n        self.feed_forward = InternLM2MLP(config)\n        self.attention_norm = InternLM2RMSNorm(config.hidden_size, 
eps=config.rms_norm_eps)\n        self.ffn_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        output_attentions: Optional[bool] = False,\n        use_cache: Optional[bool] = False,\n        **kwargs,\n    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            attention_mask (`torch.FloatTensor`, *optional*):\n                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,\n                query_sequence_length, key_sequence_length)` if default attention is used.\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n                returned tensors for more detail.\n            use_cache (`bool`, *optional*):\n                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n                (see `past_key_values`).\n            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n        \"\"\"\n        if \"padding_mask\" in kwargs:\n            warnings.warn(\"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead.\")\n\n        residual = hidden_states\n\n        hidden_states = self.attention_norm(hidden_states)\n\n        # Self Attention\n        hidden_states, self_attn_weights, present_key_value = self.attention(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_value=past_key_value,\n            output_attentions=output_attentions,\n            use_cache=use_cache,\n            **kwargs,\n        )\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = self.ffn_norm(hidden_states)\n        hidden_states = self.feed_forward(hidden_states)\n        hidden_states = residual + hidden_states\n\n        outputs = (hidden_states,)\n\n        if output_attentions:\n            outputs += (self_attn_weights,)\n\n        if use_cache:\n            outputs += (present_key_value,)\n\n        return outputs\n\n\nInternLM2_START_DOCSTRING = r\"\"\"\n    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the\n    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads\n    etc.)\n\n    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.\n    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage\n    and behavior.\n\n    Parameters:\n        config ([`InternLM2Config`]):\n            Model configuration class with all the parameters of the model. Initializing with a config file does not\n            load the weights associated with the model, only the configuration. 
Check out the\n            [`~PreTrainedModel.from_pretrained`] method to load the model weights.\n\"\"\"\n\n\n# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->InternLM2\n@add_start_docstrings(\n    \"The bare InternLM2 Model outputting raw hidden-states without any specific head on top.\",\n    InternLM2_START_DOCSTRING,\n)\nclass InternLM2PreTrainedModel(PreTrainedModel):\n    config_class = InternLM2Config\n    base_model_prefix = \"model\"\n    supports_gradient_checkpointing = True\n    _no_split_modules = [\"InternLM2DecoderLayer\"]\n    _skip_keys_device_placement = \"past_key_values\"\n    _supports_flash_attn_2 = True\n\n    def _init_weights(self, module):\n        std = self.config.initializer_range\n        if isinstance(module, nn.Linear):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n\n\nInternLM2_INPUTS_DOCSTRING = r\"\"\"\n    Args:\n        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide\n            it.\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            [What are input IDs?](../glossary#input-ids)\n        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n\n            [What are attention masks?](../glossary#attention-mask)\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see\n            `past_key_values`).\n\n            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]\n            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more\n            information on the default strategy.\n\n            - 1 indicates the head is **not masked**,\n            - 0 indicates the head is **masked**.\n        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Indices of positions of each input sequence tokens in the position embeddings. 
Selected in the range `[0,\n            config.n_positions - 1]`.\n\n            [What are position IDs?](../glossary#position-ids)\n        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or\n            when `config.use_cache=True`):\n            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape\n            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape\n            `(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)`.\n\n            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention\n            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.\n\n            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't\n            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`\n            of shape `(batch_size, sequence_length)`.\n        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):\n            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This\n            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the\n            model's internal embedding lookup matrix.\n        use_cache (`bool`, *optional*):\n            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see\n            `past_key_values`).\n        output_attentions (`bool`, *optional*):\n            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned\n            tensors for more detail.\n        output_hidden_states (`bool`, *optional*):\n            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for\n            more detail.\n        return_dict (`bool`, *optional*):\n            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n\"\"\"\n\n\n# Modified from transformers.model.llama.modeling_llama.LlamaModel\n@add_start_docstrings(\n    \"The bare InternLM2 Model outputting raw hidden-states without any specific head on top.\",\n    InternLM2_START_DOCSTRING,\n)\nclass InternLM2Model(InternLM2PreTrainedModel):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`InternLM2DecoderLayer`]\n\n    Args:\n        config: InternLM2Config\n    \"\"\"\n\n    _auto_class = \"AutoModel\"\n\n    def __init__(self, config: InternLM2Config):\n        super().__init__(config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n        self.config = config\n        if not has_flash_attn:\n            self.config.attn_implementation = \"eager\"\n            print(\"Warning: Flash attention is not available, using eager attention instead.\")\n\n        self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)\n\n        self.layers = nn.ModuleList([InternLM2DecoderLayer(config) for _ in range(config.num_hidden_layers)])\n        self.norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n        self.gradient_checkpointing = False\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.tok_embeddings\n\n    def set_input_embeddings(self, value):\n        self.tok_embeddings = value\n\n    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):\n        # create causal mask\n        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n        combined_attention_mask = None\n        if input_shape[-1] > 1:\n            combined_attention_mask = _make_causal_mask(\n                input_shape,\n                inputs_embeds.dtype,\n                device=inputs_embeds.device,\n                past_key_values_length=past_key_values_length,\n            )\n\n        if attention_mask is not None:\n            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device)\n            combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask\n\n        return combined_attention_mask\n\n    @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, BaseModelOutputWithPast]:\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        if self.config.attn_implementation == \"flash_attention_2\":\n            _import_flash_attn()\n\n        # retrieve input_ids and inputs_embeds\n        if input_ids is not None and inputs_embeds is not None:\n            raise ValueError(\"You cannot specify both input_ids and inputs_embeds at the same time\")\n        elif input_ids is not None:\n            batch_size, seq_length = 
input_ids.shape[:2]\n        elif inputs_embeds is not None:\n            batch_size, seq_length = inputs_embeds.shape[:2]\n        else:\n            raise ValueError(\"You have to specify either input_ids or inputs_embeds\")\n\n        seq_length_with_past = seq_length\n        past_key_values_length = 0\n        if past_key_values is not None:\n            past_key_values_length = past_key_values[0][0].shape[2]\n            seq_length_with_past = seq_length_with_past + past_key_values_length\n\n        if position_ids is None:\n            device = input_ids.device if input_ids is not None else inputs_embeds.device\n            position_ids = torch.arange(past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device)\n            position_ids = position_ids.unsqueeze(0)\n\n        if inputs_embeds is None:\n            inputs_embeds = self.tok_embeddings(input_ids)\n\n        if self.config.attn_implementation == \"flash_attention_2\":\n            # 2d mask is passed through the layers\n            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None\n        else:\n            if attention_mask is None:\n                attention_mask = torch.ones((batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device)\n            attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length)\n\n        # embed positions\n        hidden_states = inputs_embeds\n\n        if self.gradient_checkpointing and self.training:\n            if use_cache:\n                logger.warning_once(\"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\")\n                use_cache = False\n\n        # decoder layers\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if output_attentions else None\n        next_decoder_cache = () if use_cache else None\n\n        for idx, decoder_layer in enumerate(self.layers):\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n\n            past_key_value = past_key_values[idx] if past_key_values is not None else None\n\n            if self.gradient_checkpointing and self.training:\n\n                def create_custom_forward(module):\n                    def custom_forward(*inputs):\n                        # None for past_key_value\n                        return module(*inputs, output_attentions, None)\n\n                    return custom_forward\n\n                layer_outputs = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(decoder_layer),\n                    hidden_states,\n                    attention_mask,\n                    position_ids,\n                    None,\n                )\n            else:\n                layer_outputs = decoder_layer(\n                    hidden_states,\n                    attention_mask=attention_mask,\n                    position_ids=position_ids,\n                    past_key_value=past_key_value,\n                    output_attentions=output_attentions,\n                    use_cache=use_cache,\n                )\n\n            hidden_states = layer_outputs[0]\n\n            if use_cache:\n                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)\n\n            if output_attentions:\n                all_self_attns += (layer_outputs[1],)\n\n        hidden_states = 
self.norm(hidden_states)\n\n        # add hidden states from the last decoder layer\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n\n        next_cache = next_decoder_cache if use_cache else None\n        if not return_dict:\n            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)\n        return BaseModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=next_cache,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attns,\n        )\n\n\n# Modified from transformers.model.llama.modeling_llama.LlamaForCausalLM\nclass InternLM2ForCausalLM(InternLM2PreTrainedModel):\n    _auto_class = \"AutoModelForCausalLM\"\n\n    _tied_weights_keys = [\"output.weight\"]\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.model = InternLM2Model(config)\n        self.vocab_size = config.vocab_size\n        self.output = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.tok_embeddings\n\n    def set_input_embeddings(self, value):\n        self.model.tok_embeddings = value\n\n    def get_output_embeddings(self):\n        return self.output\n\n    def set_output_embeddings(self, new_embeddings):\n        self.output = new_embeddings\n\n    def set_decoder(self, decoder):\n        self.model = decoder\n\n    def get_decoder(self):\n        return self.model\n\n    @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)\n    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        r\"\"\"\n        Args:\n            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n        Returns:\n\n        Example:\n\n        ```python\n        >>> from transformers import AutoTokenizer, InternLM2ForCausalLM\n\n        >>> model = InternLM2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)\n        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)\n\n        >>> prompt = \"Hey, are you conscious? 
Can you talk to me?\"\n        >>> inputs = tokenizer(prompt, return_tensors=\"pt\")\n\n        >>> # Generate\n        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)\n        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n        \"Hey, are you conscious? Can you talk to me?\\nI'm not conscious, but I can talk to you.\"\n        ```\"\"\"\n\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs = self.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        hidden_states = outputs[0]\n        logits = self.output(hidden_states)\n        logits = logits.float()\n\n        loss = None\n        if labels is not None:\n            # Shift so that tokens < n predict n\n            shift_logits = logits[..., :-1, :].contiguous()\n            shift_labels = labels[..., 1:].contiguous()\n            # Flatten the tokens\n            loss_fct = CrossEntropyLoss()\n            shift_logits = shift_logits.view(-1, self.config.vocab_size)\n            shift_labels = shift_labels.view(-1)\n            # Enable model parallelism\n            shift_labels = shift_labels.to(shift_logits.device)\n            loss = loss_fct(shift_logits, shift_labels)\n\n        if not return_dict:\n            output = (logits,) + outputs[1:]\n            return (loss,) + output if loss is not None else output\n\n        device = input_ids.device if input_ids is not None else inputs_embeds.device\n        output = CausalLMOutputWithPast(\n            loss=loss,\n            logits=logits,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n        output[\"logits\"] = output[\"logits\"].to(device)\n        return output\n\n    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs):\n        if past_key_values is not None:\n            past_length = past_key_values[0][0].shape[2]\n\n            # Some generation methods already pass only the last input ID\n            if input_ids.shape[1] > past_length:\n                remove_prefix_length = past_length\n            else:\n                # Default to old behavior: keep only final ID\n                remove_prefix_length = input_ids.shape[1] - 1\n\n            input_ids = input_ids[:, remove_prefix_length:]\n\n        position_ids = kwargs.get(\"position_ids\", None)\n        if attention_mask is not None and position_ids is None:\n            # create position_ids on the fly for batch generation\n            position_ids = attention_mask.long().cumsum(-1) - 1\n            position_ids.masked_fill_(attention_mask == 0, 1)\n            if past_key_values:\n                position_ids = 
position_ids[:, -input_ids.shape[1] :]\n\n        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step\n        if inputs_embeds is not None and past_key_values is None:\n            model_inputs = {\"inputs_embeds\": inputs_embeds}\n        else:\n            model_inputs = {\"input_ids\": input_ids}\n\n        model_inputs.update(\n            {\n                \"position_ids\": position_ids,\n                \"past_key_values\": past_key_values,\n                \"use_cache\": kwargs.get(\"use_cache\"),\n                \"attention_mask\": attention_mask,\n            }\n        )\n        return model_inputs\n\n    @staticmethod\n    def _reorder_cache(past_key_values, beam_idx):\n        reordered_past = ()\n        for layer_past in past_key_values:\n            reordered_past += (tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),)\n        return reordered_past\n\n    def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = [], meta_instruction=\"\"):\n        if tokenizer.add_bos_token:\n            prompt = \"\"\n        else:\n            prompt = tokenizer.bos_token\n        if meta_instruction:\n            prompt += f\"\"\"<|im_start|>system\\n{meta_instruction}<|im_end|>\\n\"\"\"\n        for record in history:\n            prompt += f\"\"\"<|im_start|>user\\n{record[0]}<|im_end|>\\n<|im_start|>assistant\\n{record[1]}<|im_end|>\\n\"\"\"\n        prompt += f\"\"\"<|im_start|>user\\n{query}<|im_end|>\\n<|im_start|>assistant\\n\"\"\"\n        return tokenizer([prompt], return_tensors=\"pt\")\n\n    @torch.no_grad()\n    def chat(\n        self,\n        tokenizer,\n        query: str,\n        history: List[Tuple[str, str]] = [],\n        streamer: Optional[BaseStreamer] = None,\n        max_new_tokens: int = 1024,\n        do_sample: bool = True,\n        temperature: float = 0.8,\n        top_p: float = 0.8,\n        meta_instruction: str = \"You are an AI assistant whose name is InternLM (书生·浦语).\\n\"\n        \"- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). 
It is designed to be helpful, honest, and harmless.\\n\"\n        \"- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such as English and 中文.\",\n        **kwargs,\n    ):\n        inputs = self.build_inputs(tokenizer, query, history, meta_instruction)\n        inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}\n        # also add end-of-assistant token in eos token id to avoid unnecessary generation\n        eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids([\"<|im_end|>\"])[0]]\n        outputs = self.generate(\n            **inputs,\n            streamer=streamer,\n            max_new_tokens=max_new_tokens,\n            do_sample=do_sample,\n            temperature=temperature,\n            top_p=top_p,\n            eos_token_id=eos_token_id,\n            **kwargs,\n        )\n        outputs = outputs[0].cpu().tolist()[len(inputs[\"input_ids\"][0]) :]\n        response = tokenizer.decode(outputs, skip_special_tokens=True)\n        response = response.split(\"<|im_end|>\")[0]\n        history = history + [(query, response)]\n        return response, history\n\n    @torch.no_grad()\n    def stream_chat(\n        self,\n        tokenizer,\n        query: str,\n        history: List[Tuple[str, str]] = [],\n        max_new_tokens: int = 1024,\n        do_sample: bool = True,\n        temperature: float = 0.8,\n        top_p: float = 0.8,\n        **kwargs,\n    ):\n        \"\"\"\n        Return a generator in format: (response, history)\n        Eg.\n        ('你好，有什么可以帮助您的吗', [('你好', '你好，有什么可以帮助您的吗')])\n        ('你好，有什么可以帮助您的吗？', [('你好', '你好，有什么可以帮助您的吗？')])\n        \"\"\"\n        if BaseStreamer is None:\n            raise ModuleNotFoundError(\"The version of `transformers` is too low. 
Please make sure that you have installed `transformers>=4.28.0`.\")\n\n        response_queue = queue.Queue(maxsize=20)\n\n        class ChatStreamer(BaseStreamer):\n            def __init__(self, tokenizer) -> None:\n                super().__init__()\n                self.tokenizer = tokenizer\n                self.queue = response_queue\n                self.query = query\n                self.history = history\n                self.response = \"\"\n                self.cache = []\n                self.received_inputs = False\n                self.queue.put((self.response, history + [(self.query, self.response)]))\n\n            def put(self, value):\n                if len(value.shape) > 1 and value.shape[0] > 1:\n                    raise ValueError(\"ChatStreamer only supports batch size 1\")\n                elif len(value.shape) > 1:\n                    value = value[0]\n\n                if not self.received_inputs:\n                    # The first received value is input_ids, ignore here\n                    self.received_inputs = True\n                    return\n\n                self.cache.extend(value.tolist())\n                token = self.tokenizer.decode(self.cache, skip_special_tokens=True)\n                if token.strip() != \"<|im_end|>\":\n                    self.response = self.response + token\n                    history = self.history + [(self.query, self.response)]\n                    self.queue.put((self.response, history))\n                    self.cache = []\n                else:\n                    self.end()\n\n            def end(self):\n                self.queue.put(None)\n\n        def stream_producer():\n            return self.chat(\n                tokenizer=tokenizer,\n                query=query,\n                streamer=ChatStreamer(tokenizer=tokenizer),\n                history=history,\n                max_new_tokens=max_new_tokens,\n                do_sample=do_sample,\n                temperature=temperature,\n                top_p=top_p,\n                **kwargs,\n            )\n\n        def consumer():\n            producer = threading.Thread(target=stream_producer)\n            producer.start()\n            while True:\n                res = response_queue.get()\n                if res is None:\n                    return\n                yield res\n\n        return consumer()\n\n\n# Copied from transformers.model.llama.modeling_llama.LlamaForSequenceClassification with Llama->InternLM2\n@add_start_docstrings(\n    \"\"\"\n    The InternLM2 Model transformer with a sequence classification head on top (linear layer).\n\n    [`InternLM2ForSequenceClassification`] uses the last token in order to do the classification,\n    as other causal models (e.g. GPT-2) do.\n\n    Since it does classification on the last token, it requires to know the position of the last token. If a\n    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If\n    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. 
Since it cannot guess the\n    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in\n    each row of the batch).\n    \"\"\",\n    InternLM2_START_DOCSTRING,\n)\nclass InternLM2ForSequenceClassification(InternLM2PreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n        self.model = InternLM2Model(config)\n        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.tok_embeddings\n\n    def set_input_embeddings(self, value):\n        self.model.tok_embeddings = value\n\n    @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\n            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\n            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        transformer_outputs = self.model(\n            input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        hidden_states = transformer_outputs[0]\n        logits = self.score(hidden_states)\n\n        if input_ids is not None:\n            batch_size = input_ids.shape[0]\n        else:\n            batch_size = inputs_embeds.shape[0]\n\n        if self.config.pad_token_id is None and batch_size != 1:\n            raise ValueError(\"Cannot handle batch sizes > 1 if no padding token is defined.\")\n        if self.config.pad_token_id is None:\n            sequence_lengths = -1\n        else:\n            if input_ids is not None:\n                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(logits.device)\n            else:\n                sequence_lengths = -1\n\n        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]\n\n        loss = None\n        if labels is not None:\n            labels = labels.to(logits.device)\n            if self.config.problem_type is None:\n                if self.num_labels == 1:\n                    self.config.problem_type = \"regression\"\n                elif self.num_labels > 1 and 
(labels.dtype == torch.long or labels.dtype == torch.int):\n                    self.config.problem_type = \"single_label_classification\"\n                else:\n                    self.config.problem_type = \"multi_label_classification\"\n\n            if self.config.problem_type == \"regression\":\n                loss_fct = MSELoss()\n                if self.num_labels == 1:\n                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())\n                else:\n                    loss = loss_fct(pooled_logits, labels)\n            elif self.config.problem_type == \"single_label_classification\":\n                loss_fct = CrossEntropyLoss()\n                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))\n            elif self.config.problem_type == \"multi_label_classification\":\n                loss_fct = BCEWithLogitsLoss()\n                loss = loss_fct(pooled_logits, labels)\n        if not return_dict:\n            output = (pooled_logits,) + transformer_outputs[1:]\n            return ((loss,) + output) if loss is not None else output\n\n        return SequenceClassifierOutputWithPast(\n            loss=loss,\n            logits=pooled_logits,\n            past_key_values=transformer_outputs.past_key_values,\n            hidden_states=transformer_outputs.hidden_states,\n            attentions=transformer_outputs.attentions,\n        )\n"
  },
  {
    "path": "siirl/models/transformers/internvl_chat/modeling_internvl_chat.py",
    "content": "# --------------------------------------------------------\n# InternVL\n# Copyright (c) 2024 OpenGVLab\n# Licensed under The MIT License [see LICENSE for details]\n# --------------------------------------------------------\n\nimport warnings\nfrom typing import List, Optional, Tuple, Union\n\nimport torch.utils.checkpoint\nimport transformers\nfrom torch import nn\nfrom torch.nn import CrossEntropyLoss\nfrom transformers import AutoModel, GenerationConfig, LlamaForCausalLM, LlamaTokenizer\nfrom transformers.modeling_outputs import CausalLMOutputWithPast\nfrom transformers.modeling_utils import PreTrainedModel\nfrom transformers.utils import ModelOutput, logging\n\nfrom .configuration_internvl_chat import InternVLChatConfig\nfrom .modeling_intern_vit import InternVisionModel, has_flash_attn\nfrom .modeling_internlm2 import InternLM2ForCausalLM\n\nfrom siirl.models.transformers.internvl import get_conv_template\n\nlogger = logging.get_logger(__name__)\n\n\ndef version_cmp(v1, v2, op=\"eq\"):\n    import operator\n\n    from packaging import version\n\n    op_func = getattr(operator, op)\n    return op_func(version.parse(v1), version.parse(v2))\n\n\nclass InternVLChatModel(PreTrainedModel):\n    config_class = InternVLChatConfig\n    main_input_name = \"pixel_values\"\n    base_model_prefix = \"language_model\"\n    _supports_flash_attn_2 = True\n    _no_split_modules = [\"InternVisionEncoderLayer\", \"LlamaDecoderLayer\", \"InternLM2DecoderLayer\"]\n    supports_gradient_checkpointing = True\n\n    def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None, use_flash_attn=True):\n        super().__init__(config)\n\n        assert version_cmp(transformers.__version__, \"4.37.0\", \"ge\")\n        image_size = config.force_image_size or config.vision_config.image_size\n        patch_size = config.vision_config.patch_size\n        self.patch_size = patch_size\n        self.select_layer = config.select_layer\n        self.template = config.template\n        self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio**2))\n        self.downsample_ratio = config.downsample_ratio\n        self.ps_version = config.ps_version\n        use_flash_attn = use_flash_attn if has_flash_attn else False\n        config.vision_config.use_flash_attn = True if use_flash_attn else False\n        config.llm_config.attn_implementation = \"flash_attention_2\" if use_flash_attn else \"eager\"\n\n        logger.info(f\"num_image_token: {self.num_image_token}\")\n        logger.info(f\"ps_version: {self.ps_version}\")\n        if vision_model is not None:\n            self.vision_model = vision_model\n        else:\n            self.vision_model = InternVisionModel(config.vision_config)\n        if language_model is not None:\n            self.language_model = language_model\n        else:\n            if config.llm_config.architectures[0] == \"LlamaForCausalLM\":\n                self.language_model = LlamaForCausalLM(config.llm_config)\n            elif config.llm_config.architectures[0] == \"InternLM2ForCausalLM\":\n                self.language_model = InternLM2ForCausalLM(config.llm_config)\n            else:\n                raise NotImplementedError(f\"{config.llm_config.architectures[0]} is not implemented.\")\n\n        vit_hidden_size = config.vision_config.hidden_size\n        llm_hidden_size = config.llm_config.hidden_size\n\n        self.mlp1 = nn.Sequential(nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2), 
nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size), nn.GELU(), nn.Linear(llm_hidden_size, llm_hidden_size))\n\n        self.img_context_token_id = None\n        self.conv_template = get_conv_template(self.template)\n        self.system_message = self.conv_template.system_message\n\n    def forward(\n        self,\n        pixel_values: torch.FloatTensor,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        image_flags: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        verbose: bool = False,\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        image_flags = image_flags.squeeze(-1)\n        input_embeds = self.language_model.get_input_embeddings()(input_ids).clone()\n\n        vit_embeds = self.extract_feature(pixel_values)\n        vit_embeds = vit_embeds[image_flags == 1]\n        vit_batch_size = pixel_values.shape[0]\n\n        B, N, C = input_embeds.shape\n        input_embeds = input_embeds.reshape(B * N, C)\n\n        if verbose and torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:\n            print(f\"dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, dynamic token length: {N}\")\n\n        input_ids = input_ids.reshape(B * N)\n        selected = input_ids == self.img_context_token_id\n        try:\n            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)\n        except Exception as e:\n            vit_embeds = vit_embeds.reshape(-1, C)\n            print(f\"warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, vit_embeds.shape={vit_embeds.shape}\")\n            n_token = selected.sum()\n            pad_size = n_token - vit_embeds.size(0)\n            if pad_size > 0:\n                vit_embeds = nn.functional.pad(vit_embeds, (0, 0, 0, pad_size))\n            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]\n\n        input_embeds = input_embeds.reshape(B, N, C)\n\n        outputs = self.language_model(\n            inputs_embeds=input_embeds,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        logits = outputs.logits\n\n        loss = None\n        if labels is not None:\n            # Shift so that tokens < n predict n\n            shift_logits = logits[..., :-1, :].contiguous()\n            shift_labels = labels[..., 1:].contiguous()\n            # Flatten the tokens\n            loss_fct = CrossEntropyLoss()\n            shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)\n            shift_labels = shift_labels.view(-1)\n            # Enable model parallelism\n            shift_labels = shift_labels.to(shift_logits.device)\n            loss = loss_fct(shift_logits, shift_labels)\n\n        if not 
return_dict:\n            output = (logits,) + outputs[1:]\n            return (loss,) + output if loss is not None else output\n\n        return CausalLMOutputWithPast(\n            loss=loss,\n            logits=logits,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n    def pixel_shuffle(self, x, scale_factor=0.5):\n        n, w, h, c = x.size()\n        # N, W, H, C --> N, W, H * scale, C // scale\n        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))\n        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale\n        x = x.permute(0, 2, 1, 3).contiguous()\n        # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)\n        x = x.view(n, int(h * scale_factor), int(w * scale_factor), int(c / (scale_factor * scale_factor)))\n        if self.ps_version == \"v1\":\n            warnings.warn(\"In ps_version 'v1', the height and width have not been swapped back, which results in a transposed image.\")\n        else:\n            x = x.permute(0, 2, 1, 3).contiguous()\n        return x\n\n    def extract_feature(self, pixel_values):\n        if self.select_layer == -1:\n            vit_embeds = self.vision_model(pixel_values=pixel_values, output_hidden_states=False, return_dict=True).last_hidden_state\n        else:\n            vit_embeds = self.vision_model(pixel_values=pixel_values, output_hidden_states=True, return_dict=True).hidden_states[self.select_layer]\n        vit_embeds = vit_embeds[:, 1:, :]\n\n        h = w = int(vit_embeds.shape[1] ** 0.5)\n        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)\n        vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)\n        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])\n        vit_embeds = self.mlp1(vit_embeds)\n        return vit_embeds\n\n    def batch_chat(self, tokenizer, pixel_values, questions, generation_config, num_patches_list=None, history=None, return_history=False, IMG_START_TOKEN=\"<img>\", IMG_END_TOKEN=\"</img>\", IMG_CONTEXT_TOKEN=\"<IMG_CONTEXT>\", verbose=False, image_counts=None):\n        if history is not None or return_history:\n            print(\"Now multi-turn chat is not supported in batch_chat.\")\n            raise NotImplementedError\n\n        if image_counts is not None:\n            num_patches_list = image_counts\n            print(\"Warning: `image_counts` is deprecated. 
Please use `num_patches_list` instead.\")\n\n        img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)\n        self.img_context_token_id = img_context_token_id\n\n        if verbose and pixel_values is not None:\n            image_bs = pixel_values.shape[0]\n            print(f\"dynamic ViT batch size: {image_bs}\")\n\n        queries = []\n        for idx, num_patches in enumerate(num_patches_list):\n            question = questions[idx]\n            if pixel_values is not None and \"<image>\" not in question:\n                question = \"<image>\\n\" + question\n            template = get_conv_template(self.template)\n            template.system_message = self.system_message\n            template.append_message(template.roles[0], question)\n            template.append_message(template.roles[1], None)\n            query = template.get_prompt()\n\n            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN\n            query = query.replace(\"<image>\", image_tokens, 1)\n            queries.append(query)\n\n        tokenizer.padding_side = \"left\"\n        model_inputs = tokenizer(queries, return_tensors=\"pt\", padding=True)\n        input_ids = model_inputs[\"input_ids\"].to(self.device)\n        attention_mask = model_inputs[\"attention_mask\"].to(self.device)\n        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip())\n        generation_config[\"eos_token_id\"] = eos_token_id\n        generation_output = self.generate(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask, **generation_config)\n        responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)\n        responses = [response.split(template.sep.strip())[0].strip() for response in responses]\n        return responses\n\n    def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False, num_patches_list=None, IMG_START_TOKEN=\"<img>\", IMG_END_TOKEN=\"</img>\", IMG_CONTEXT_TOKEN=\"<IMG_CONTEXT>\", verbose=False):\n        if history is None and pixel_values is not None and \"<image>\" not in question:\n            question = \"<image>\\n\" + question\n\n        if num_patches_list is None:\n            num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []\n        assert pixel_values is None or len(pixel_values) == sum(num_patches_list)\n\n        img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)\n        self.img_context_token_id = img_context_token_id\n\n        template = get_conv_template(self.template)\n        template.system_message = self.system_message\n        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip())\n\n        history = [] if history is None else history\n        for old_question, old_answer in history:\n            template.append_message(template.roles[0], old_question)\n            template.append_message(template.roles[1], old_answer)\n        template.append_message(template.roles[0], question)\n        template.append_message(template.roles[1], None)\n        query = template.get_prompt()\n\n        if verbose and pixel_values is not None:\n            image_bs = pixel_values.shape[0]\n            print(f\"dynamic ViT batch size: {image_bs}\")\n\n        for num_patches in num_patches_list:\n            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN\n            query = 
query.replace(\"<image>\", image_tokens, 1)\n\n        model_inputs = tokenizer(query, return_tensors=\"pt\")\n        input_ids = model_inputs[\"input_ids\"].to(self.device)\n        attention_mask = model_inputs[\"attention_mask\"].to(self.device)\n        generation_config[\"eos_token_id\"] = eos_token_id\n        generation_output = self.generate(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask, **generation_config)\n        response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]\n        response = response.split(template.sep.strip())[0].strip()\n        history.append((question, response))\n        if return_history:\n            return response, history\n        else:\n            query_to_print = query.replace(IMG_CONTEXT_TOKEN, \"\")\n            query_to_print = query_to_print.replace(f\"{IMG_START_TOKEN}{IMG_END_TOKEN}\", \"<image>\")\n            if verbose:\n                print(query_to_print, response)\n            return response\n\n    @torch.no_grad()\n    def generate(\n        self,\n        pixel_values: Optional[torch.FloatTensor] = None,\n        input_ids: Optional[torch.FloatTensor] = None,\n        attention_mask: Optional[torch.LongTensor] = None,\n        visual_features: Optional[torch.FloatTensor] = None,\n        generation_config: Optional[GenerationConfig] = None,\n        output_hidden_states: Optional[bool] = None,\n        **generate_kwargs,\n    ) -> torch.LongTensor:\n        assert self.img_context_token_id is not None\n        if pixel_values is not None:\n            if visual_features is not None:\n                vit_embeds = visual_features\n            else:\n                vit_embeds = self.extract_feature(pixel_values)\n            input_embeds = self.language_model.get_input_embeddings()(input_ids)\n            B, N, C = input_embeds.shape\n            input_embeds = input_embeds.reshape(B * N, C)\n\n            input_ids = input_ids.reshape(B * N)\n            selected = input_ids == self.img_context_token_id\n            assert selected.sum() != 0\n            input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)\n\n            input_embeds = input_embeds.reshape(B, N, C)\n        else:\n            input_embeds = self.language_model.get_input_embeddings()(input_ids)\n\n        outputs = self.language_model.generate(\n            inputs_embeds=input_embeds,\n            attention_mask=attention_mask,\n            generation_config=generation_config,\n            output_hidden_states=output_hidden_states,\n            use_cache=True,\n            **generate_kwargs,\n        )\n\n        return outputs\n"
  },
  {
    "path": "siirl/models/transformers/internvl_chat/tokenization_internlm2.py",
    "content": "# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on transformers/src/transformers/models/llama/tokenization_llama.py\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tokenization classes for InternLM.\"\"\"\n\nimport os\nfrom shutil import copyfile\nfrom typing import Any, Dict, List, Optional, Tuple\n\nimport sentencepiece as spm\nfrom transformers.tokenization_utils import PreTrainedTokenizer\nfrom transformers.utils import logging\n\nlogger = logging.get_logger(__name__)\n\nVOCAB_FILES_NAMES = {\"vocab_file\": \"./tokenizer.model\"}\n\nPRETRAINED_VOCAB_FILES_MAP = {}\n\n\n# Modified from transformers.model.llama.tokenization_llama.LlamaTokenizer\nclass InternLM2Tokenizer(PreTrainedTokenizer):\n    \"\"\"\n    Construct a InternLM2 tokenizer. Based on byte-level Byte-Pair-Encoding.\n\n    Args:\n        vocab_file (`str`):\n            Path to the vocabulary file.\n    \"\"\"\n\n    vocab_files_names = VOCAB_FILES_NAMES\n    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP\n    model_input_names = [\"input_ids\", \"attention_mask\"]\n    _auto_class = \"AutoTokenizer\"\n\n    def __init__(\n        self,\n        vocab_file,\n        unk_token=\"<unk>\",\n        bos_token=\"<s>\",\n        eos_token=\"</s>\",\n        pad_token=\"</s>\",\n        sp_model_kwargs: Optional[Dict[str, Any]] = None,\n        add_bos_token=True,\n        add_eos_token=False,\n        decode_with_prefix_space=False,\n        clean_up_tokenization_spaces=False,\n        **kwargs,\n    ):\n        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs\n        self.vocab_file = vocab_file\n        self.add_bos_token = add_bos_token\n        self.add_eos_token = add_eos_token\n        self.decode_with_prefix_space = decode_with_prefix_space\n        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)\n        self.sp_model.Load(vocab_file)\n        self._no_prefix_space_tokens = None\n        super().__init__(\n            bos_token=bos_token,\n            eos_token=eos_token,\n            unk_token=unk_token,\n            pad_token=pad_token,\n            clean_up_tokenization_spaces=clean_up_tokenization_spaces,\n            **kwargs,\n        )\n\n    @property\n    def no_prefix_space_tokens(self):\n        if self._no_prefix_space_tokens is None:\n            vocab = self.convert_ids_to_tokens(list(range(self.vocab_size)))\n            self._no_prefix_space_tokens = {i for i, tok in enumerate(vocab) if not tok.startswith(\"▁\")}\n        return self._no_prefix_space_tokens\n\n    @property\n    def vocab_size(self):\n        \"\"\"Returns vocab size\"\"\"\n        return self.sp_model.get_piece_size()\n\n    @property\n    def bos_token_id(self) -> Optional[int]:\n        return self.sp_model.bos_id()\n\n    @property\n    def eos_token_id(self) -> Optional[int]:\n        return self.sp_model.eos_id()\n\n    def get_vocab(self):\n        \"\"\"Returns vocab as a dict\"\"\"\n     
   vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}\n        vocab.update(self.added_tokens_encoder)\n        return vocab\n\n    def _tokenize(self, text):\n        \"\"\"Returns a tokenized string.\"\"\"\n        return self.sp_model.encode(text, out_type=str)\n\n    def _convert_token_to_id(self, token):\n        \"\"\"Converts a token (str) in an id using the vocab.\"\"\"\n        return self.sp_model.piece_to_id(token)\n\n    def _convert_id_to_token(self, index):\n        \"\"\"Converts an index (integer) in a token (str) using the vocab.\"\"\"\n        token = self.sp_model.IdToPiece(index)\n        return token\n\n    def _maybe_add_prefix_space(self, tokens, decoded):\n        if tokens and tokens[0] not in self.no_prefix_space_tokens:\n            return \" \" + decoded\n        else:\n            return decoded\n\n    def convert_tokens_to_string(self, tokens):\n        \"\"\"Converts a sequence of tokens (string) in a single string.\"\"\"\n        current_sub_tokens = []\n        out_string = \"\"\n        prev_is_special = False\n        for token in tokens:\n            # make sure that special tokens are not decoded using sentencepiece model\n            if token in self.all_special_tokens:\n                if not prev_is_special:\n                    out_string += \" \"\n                out_string += self.sp_model.decode(current_sub_tokens) + token\n                prev_is_special = True\n                current_sub_tokens = []\n            else:\n                current_sub_tokens.append(token)\n                prev_is_special = False\n        out_string += self.sp_model.decode(current_sub_tokens)\n        out_string = self.clean_up_tokenization(out_string)\n        out_string = self._maybe_add_prefix_space(tokens=tokens, decoded=out_string)\n        return out_string[1:]\n\n    def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:\n        \"\"\"\n        Save the vocabulary and special tokens file to a directory.\n\n        Args:\n            save_directory (`str`):\n                The directory in which to save the vocabulary.\n\n        Returns:\n            `Tuple(str)`: Paths to the files saved.\n        \"\"\"\n        if not os.path.isdir(save_directory):\n            logger.error(f\"Vocabulary path ({save_directory}) should be a directory\")\n            return\n        out_vocab_file = os.path.join(save_directory, (filename_prefix + \"-\" if filename_prefix else \"\") + VOCAB_FILES_NAMES[\"vocab_file\"])\n\n        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):\n            copyfile(self.vocab_file, out_vocab_file)\n        elif not os.path.isfile(self.vocab_file):\n            with open(out_vocab_file, \"wb\") as fi:\n                content_spiece_model = self.sp_model.serialized_model_proto()\n                fi.write(content_spiece_model)\n\n        return (out_vocab_file,)\n\n    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):\n        if self.add_bos_token:\n            bos_token_ids = [self.bos_token_id]\n        else:\n            bos_token_ids = []\n\n        output = bos_token_ids + token_ids_0\n\n        if token_ids_1 is not None:\n            output = output + token_ids_1\n\n        if self.add_eos_token:\n            output = output + [self.eos_token_id]\n\n        return output\n\n    def get_special_tokens_mask(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, 
already_has_special_tokens: bool = False) -> List[int]:\n        \"\"\"\n        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding\n        special tokens using the tokenizer `prepare_for_model` method.\n\n        Args:\n            token_ids_0 (`List[int]`):\n                List of IDs.\n            token_ids_1 (`List[int]`, *optional*):\n                Optional second list of IDs for sequence pairs.\n            already_has_special_tokens (`bool`, *optional*, defaults to `False`):\n                Whether or not the token list is already formatted with special tokens for the model.\n\n        Returns:\n            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.\n        \"\"\"\n        if already_has_special_tokens:\n            return super().get_special_tokens_mask(token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True)\n\n        if token_ids_1 is None:\n            return [1] + ([0] * len(token_ids_0)) + [1]\n        return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]\n\n    def create_token_type_ids_from_sequences(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]:\n        \"\"\"\n        Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make\n        use of token type ids, therefore a list of zeros is returned.\n\n        Args:\n            token_ids_0 (`List[int]`):\n                List of IDs.\n            token_ids_1 (`List[int]`, *optional*):\n                Optional second list of IDs for sequence pairs.\n\n        Returns:\n            `List[int]`: List of zeros.\n        \"\"\"\n        eos = [self.eos_token_id]\n\n        if token_ids_1 is None:\n            return len(token_ids_0 + eos) * [0]\n        return len(token_ids_0 + eos + token_ids_1 + eos) * [0]\n"
  },
  {
    "path": "siirl/models/transformers/internvl_chat/tokenization_internlm2_fast.py",
    "content": "# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on transformers/src/transformers/models/llama/tokenization_llama_fast.py\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tokenization Fast class for InternLM.\"\"\"\n\nimport os\nfrom shutil import copyfile\nfrom typing import Any, Dict, Optional, Tuple\n\nfrom tokenizers import Tokenizer, decoders, normalizers, processors\nfrom tokenizers.models import BPE\nfrom transformers.convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS, SentencePieceExtractor, SpmConverter\nfrom transformers.tokenization_utils_fast import PreTrainedTokenizerFast\nfrom transformers.utils import logging\n\nfrom .tokenization_internlm2 import InternLM2Tokenizer\n\nlogger = logging.get_logger(__name__)\n\nVOCAB_FILES_NAMES = {\"vocab_file\": \"./tokenizer.model\"}\n\n\n# Modified from transformers.convert_slow_tokenizer.LlamaConverter\nclass InternLM2Converter(SpmConverter):\n    handle_byte_fallback = True\n\n    def vocab(self, proto):\n        vocab = [\n            (\"<unk>\", 0.0),\n            (\"<s>\", 0.0),\n            (\"</s>\", 0.0),\n        ]\n        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]\n        return vocab\n\n    def unk_id(self, proto):\n        unk_id = 0\n        return unk_id\n\n    def decoder(self, replacement, add_prefix_space):\n        return decoders.Sequence(\n            [\n                decoders.Replace(\"▁\", \" \"),\n                decoders.ByteFallback(),\n                decoders.Fuse(),\n                decoders.Strip(content=\" \", left=1),\n            ]\n        )\n\n    def tokenizer(self, proto):\n        model_type = proto.trainer_spec.model_type\n        vocab_scores = self.vocab(proto)\n        # special tokens\n        added_tokens = self.original_tokenizer.added_tokens_decoder\n        for i in range(len(vocab_scores)):\n            piece, score = vocab_scores[i]\n            if i in added_tokens:\n                vocab_scores[i] = (added_tokens[i].content, score)\n        if model_type == 1:\n            raise RuntimeError(\"InternLM2 is supposed to be a BPE model!\")\n\n        elif model_type == 2:\n            _, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)\n            bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}\n            tokenizer = Tokenizer(BPE(bpe_vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True, byte_fallback=True))\n            tokenizer.add_special_tokens([added_token for index, added_token in added_tokens.items()])\n        else:\n            raise Exception(\"You're trying to run a `Unigram` model but you're file was trained with a different algorithm\")\n\n        return tokenizer\n\n    def normalizer(self, proto):\n        normalizers_list = []\n        if proto.normalizer_spec.add_dummy_prefix:\n            normalizers_list.append(normalizers.Prepend(prepend=\"▁\"))\n        
normalizers_list.append(normalizers.Replace(pattern=\" \", content=\"▁\"))\n        return normalizers.Sequence(normalizers_list)\n\n    def pre_tokenizer(self, replacement, add_prefix_space):\n        return None\n\n\nSLOW_TO_FAST_CONVERTERS[\"InternLM2Tokenizer\"] = InternLM2Converter\n\n\n# Modified from transformers.model.llama.tokenization_llama_fast.LlamaTokenizerFast -> InternLM2TokenizerFast\nclass InternLM2TokenizerFast(PreTrainedTokenizerFast):\n    vocab_files_names = VOCAB_FILES_NAMES\n    slow_tokenizer_class = InternLM2Tokenizer\n    padding_side = \"left\"\n    model_input_names = [\"input_ids\", \"attention_mask\"]\n    _auto_class = \"AutoTokenizer\"\n\n    def __init__(\n        self,\n        vocab_file,\n        unk_token=\"<unk>\",\n        bos_token=\"<s>\",\n        eos_token=\"</s>\",\n        pad_token=\"</s>\",\n        sp_model_kwargs: Optional[Dict[str, Any]] = None,\n        add_bos_token=True,\n        add_eos_token=False,\n        decode_with_prefix_space=False,\n        clean_up_tokenization_spaces=False,\n        **kwargs,\n    ):\n        super().__init__(\n            vocab_file=vocab_file,\n            unk_token=unk_token,\n            bos_token=bos_token,\n            eos_token=eos_token,\n            pad_token=pad_token,\n            sp_model_kwargs=sp_model_kwargs,\n            add_bos_token=add_bos_token,\n            add_eos_token=add_eos_token,\n            decode_with_prefix_space=decode_with_prefix_space,\n            clean_up_tokenization_spaces=clean_up_tokenization_spaces,\n            **kwargs,\n        )\n        self._add_bos_token = add_bos_token\n        self._add_eos_token = add_eos_token\n        self.update_post_processor()\n        self.vocab_file = vocab_file\n\n    @property\n    def can_save_slow_tokenizer(self) -> bool:\n        return os.path.isfile(self.vocab_file) if self.vocab_file else False\n\n    def update_post_processor(self):\n        \"\"\"\n        Updates the underlying post processor with the current `bos_token` and `eos_token`.\n        \"\"\"\n        bos = self.bos_token\n        bos_token_id = self.bos_token_id\n        if bos is None and self.add_bos_token:\n            raise ValueError(\"add_bos_token = True but bos_token = None\")\n\n        eos = self.eos_token\n        eos_token_id = self.eos_token_id\n        if eos is None and self.add_eos_token:\n            raise ValueError(\"add_eos_token = True but eos_token = None\")\n\n        single = f\"{(bos + ':0 ') if self.add_bos_token else ''}$A:0{(' ' + eos + ':0') if self.add_eos_token else ''}\"\n        pair = f\"{single}{(' ' + bos + ':1') if self.add_bos_token else ''} $B:1{(' ' + eos + ':1') if self.add_eos_token else ''}\"\n\n        special_tokens = []\n        if self.add_bos_token:\n            special_tokens.append((bos, bos_token_id))\n        if self.add_eos_token:\n            special_tokens.append((eos, eos_token_id))\n        self._tokenizer.post_processor = processors.TemplateProcessing(single=single, pair=pair, special_tokens=special_tokens)\n\n    @property\n    def add_eos_token(self):\n        return self._add_eos_token\n\n    @property\n    def add_bos_token(self):\n        return self._add_bos_token\n\n    @add_eos_token.setter\n    def add_eos_token(self, value):\n        self._add_eos_token = value\n        self.update_post_processor()\n\n    @add_bos_token.setter\n    def add_bos_token(self, value):\n        self._add_bos_token = value\n        self.update_post_processor()\n\n    def save_vocabulary(self, save_directory: str, 
filename_prefix: Optional[str] = None) -> Tuple[str]:\n        if not self.can_save_slow_tokenizer:\n            raise ValueError(\"Your fast tokenizer does not have the necessary information to save the vocabulary for a slow tokenizer.\")\n\n        if not os.path.isdir(save_directory):\n            logger.error(f\"Vocabulary path ({save_directory}) should be a directory\")\n            return\n        out_vocab_file = os.path.join(save_directory, (filename_prefix + \"-\" if filename_prefix else \"\") + VOCAB_FILES_NAMES[\"vocab_file\"])\n\n        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):\n            copyfile(self.vocab_file, out_vocab_file)\n\n        return (out_vocab_file,)\n"
  },
  {
    "path": "siirl/models/transformers/kimi_vl.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Optional, Tuple\n\nimport torch\nimport torch.nn.functional as F\nfrom transformers.cache_utils import Cache\nfrom transformers.modeling_flash_attention_utils import _flash_attention_forward\n\nfrom siirl.utils.model_utils.ulysses import gather_heads_scatter_seq, gather_outpus_and_unpad, gather_seq_scatter_heads, get_ulysses_sequence_parallel_group, get_ulysses_sequence_parallel_rank, get_ulysses_sequence_parallel_world_size, validate_ulysses_config\n\n\ndef _merge_with_image_features(\n    self,\n    inputs_embeds: torch.Tensor,\n    input_ids: torch.Tensor,\n    image_features: torch.Tensor,\n):\n    \"\"\"\n    Args:\n        inputs_embeds (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length, input_embed_dim)`):\n            The input embeddings.\n        input_ids (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`):\n            The input ids.\n        image_features (:obj:`torch.Tensor` of shape :obj:`(image_token_nums, image_feature_dim)`):\n            The image features to merge with the input embeddings.\n    \"\"\"\n    image_token_index: int = self.config.media_placeholder_token_id\n\n    batch_size, sequence_length, input_embed_dim = inputs_embeds.shape\n    image_feature_nums, image_feature_dim = image_features.shape\n\n    assert image_feature_dim == input_embed_dim\n\n    image_token_nums = (input_ids == image_token_index).sum()\n    total_image_token_nums = torch.tensor([image_token_nums], dtype=image_token_nums.dtype, device=input_ids.device)\n    total_image_token_nums = gather_outpus_and_unpad(total_image_token_nums, gather_dim=0)  # [sp_size]\n    assert image_feature_nums == total_image_token_nums.sum()\n\n    # (batch_size, sequence_length / sp, input_embed_dim) -> (batch_size * sequence_length / sp, input_embed_dim)\n    inputs_embeds = inputs_embeds.reshape(-1, input_embed_dim)\n\n    # (batch_size, sequence_length / sp) -> (batch_size * sequence_length / sp)\n    input_ids = input_ids.flatten()\n\n    # split image features and fill in the image token positions if there are image tokens\n    sp_image_features = image_features.split(total_image_token_nums.tolist(), dim=0)\n    sp_rank = get_ulysses_sequence_parallel_rank()\n    image_features = sp_image_features[sp_rank]\n    inputs_embeds[input_ids == image_token_index] = image_features\n\n    inputs_embeds = inputs_embeds.reshape((batch_size, sequence_length, input_embed_dim))\n\n    return inputs_embeds\n\n\n# Copied from transformers.models.llama.modeling_llama.rotate_half\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\n# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):\n    \"\"\"Applies Rotary 
Position Embedding to the query and key tensors.\n\n    Args:\n        q (`torch.Tensor`): The query tensor.\n        k (`torch.Tensor`): The key tensor.\n        cos (`torch.Tensor`): The cosine part of the rotary embedding.\n        sin (`torch.Tensor`): The sine part of the rotary embedding.\n        position_ids (`torch.Tensor`):\n            The position indices of the tokens corresponding to the query and key tensors. For example, this can be\n            used to pass offsetted position ids when working with a KV-cache.\n        unsqueeze_dim (`int`, *optional*, defaults to 1):\n            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and\n            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note\n            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and\n            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have\n            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n    Returns:\n        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.\n    \"\"\"\n    cos = cos[position_ids].unsqueeze(unsqueeze_dim)\n    sin = sin[position_ids].unsqueeze(unsqueeze_dim)\n\n    b, h, s, d = q.shape\n    q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)\n\n    b, h, s, d = k.shape\n    k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)\n\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed, k_embed\n\n\n# Copied from transformers.models.llama.modeling_llama.repeat_kv\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\ndef _ulysses_flash_attn_forward(\n    self,\n    hidden_states: torch.Tensor,\n    attention_mask: Optional[torch.LongTensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_value: Optional[Cache] = None,\n    output_attentions: bool = False,\n    use_cache: bool = False,\n    **kwargs,\n) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n    bsz, q_len, _ = hidden_states.size()\n\n    if self.q_lora_rank is None:\n        q = self.q_proj(hidden_states)\n    else:\n        q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))\n    q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)\n    q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)\n\n    # Flash attention requires the input to have the shape\n    # batch_size x seq_length x head_dim x hidden_dim\n    # therefore we just need to keep the original shape\n    compressed_kv = self.kv_a_proj_with_mqa(hidden_states)\n    compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)\n    k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)\n    kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv)).view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).transpose(1, 2)\n\n    k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)\n    kv_seq_len = value_states.shape[-2]\n\n    # patch to get all emb\n    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()\n    kv_seq_len *= ulysses_sp_size\n\n    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)\n    q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)\n\n    query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)\n    query_states[:, :, :, : self.qk_nope_head_dim] = q_nope\n    query_states[:, :, :, self.qk_nope_head_dim :] = q_pe\n\n    key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)\n    key_states[:, :, :, : self.qk_nope_head_dim] = k_nope\n    key_states[:, :, :, self.qk_nope_head_dim :] = k_pe\n\n    if self.q_head_dim != self.v_head_dim:\n        value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim])\n\n    # patch\n    if ulysses_sp_size > 1:\n        validate_ulysses_config(self.num_heads, ulysses_sp_size)\n\n        num_key_value_groups = self.config.num_attention_heads // self.config.num_key_value_heads\n        key_states = repeat_kv(key_states, num_key_value_groups)\n        value_states = repeat_kv(value_states, num_key_value_groups)\n        query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)\n        key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)\n        value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)\n        # (batch_size, num_head / sp_size, seq_length, head_size)\n        full_q_len = query_states.size(2)  # full_q_len = seq_length\n\n        position_ids_list = [torch.empty_like(position_ids) for _ in 
range(ulysses_sp_size)]\n        torch.distributed.all_gather(position_ids_list, position_ids, group=get_ulysses_sequence_parallel_group())\n        position_ids = torch.concat(position_ids_list, dim=-1)\n\n    else:\n        full_q_len = q_len\n\n    # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache\n    # to be able to avoid many of these transpose/reshape/view.\n    query_states = query_states.transpose(1, 2)\n    key_states = key_states.transpose(1, 2)\n    value_states = value_states.transpose(1, 2)\n\n    dropout_rate = self.attention_dropout if self.training else 0.0\n\n    attn_output = _flash_attention_forward(\n        query_states,\n        key_states,\n        value_states,\n        attention_mask,\n        full_q_len,\n        dropout=dropout_rate,\n        sliding_window=None,\n        is_causal=self.is_causal,\n        use_top_left_mask=self._flash_attn_uses_top_left_mask,\n        position_ids=position_ids,  # important: pass position ids\n        softmax_scale=self.softmax_scale,\n    )\n\n    if ulysses_sp_size > 1:\n        attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1)\n\n    if self.q_head_dim != self.v_head_dim:\n        attn_output = attn_output[:, :, :, : self.v_head_dim]\n\n    attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim).contiguous()\n    attn_output = self.o_proj(attn_output)\n\n    return attn_output, None, None\n"
  },
  {
    "path": "siirl/models/transformers/llama.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport sys\nfrom dataclasses import dataclass\nfrom typing import Callable, List, Optional, Tuple, Union\n\nimport torch\n\nif sys.version_info >= (3, 11):\n    pass\nelse:\n    pass\n\nfrom transformers.cache_utils import Cache\nfrom transformers.modeling_flash_attention_utils import _flash_attention_forward\nfrom transformers.modeling_outputs import CausalLMOutputWithPast\nfrom transformers.models.llama.modeling_llama import apply_rotary_pos_emb\nfrom transformers.utils import logging\n\nfrom siirl.utils.model_utils.ulysses import (\n    gather_heads_scatter_seq,\n    gather_seq_scatter_heads,\n    get_ulysses_sequence_parallel_world_size,\n    validate_ulysses_config,\n)\n\nlogger = logging.get_logger(__name__)\n\n\ndef llama_flash_attn_forward(\n    self,\n    hidden_states: torch.Tensor,\n    attention_mask: Optional[torch.LongTensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_value: Optional[Cache] = None,\n    output_attentions: bool = False,\n    use_cache: bool = False,\n    cache_position: Optional[torch.LongTensor] = None,\n    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46\n    **kwargs,\n) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n    \"\"\"\n    Adapted from transformers 4.47.1 to support Ulysses sequence parallelism.\n\n    NOTE: This function is used for transformers versions in the range [4.45.0, 4.47.1].\n    \"\"\"\n    output_attentions = False\n\n    bsz, q_len, _ = hidden_states.size()\n\n    query_states = self.q_proj(hidden_states)\n    key_states = self.k_proj(hidden_states)\n    value_states = self.v_proj(hidden_states)\n\n    # Flash attention requires the input to have the shape\n    # batch_size x seq_length x head_dim x hidden_dim\n    # therefore we just need to keep the original shape\n    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n\n    # trade off: repeat first and then all to all\n    # key_states = repeat_kv(key_states, self.num_key_value_groups)\n    # value_states = repeat_kv(value_states, self.num_key_value_groups)\n\n    ########## AlltoAll for Ulysses ##########\n    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()\n\n    if ulysses_sp_size > 1:\n        validate_ulysses_config(self.num_heads, ulysses_sp_size)\n\n        # (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim)\n        query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)\n        key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)\n        value_states = gather_seq_scatter_heads(value_states, seq_dim=2, 
head_dim=1)\n\n    full_q_len = query_states.size(2)  # full seq length\n\n    if position_embeddings is None:\n        logger.warning_once(\n            \"The attention layers in this model are transitioning from computing the RoPE embeddings internally \"\n            \"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed \"\n            \"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be \"\n            \"removed and `position_embeddings` will be mandatory.\"\n        )\n        cos, sin = self.rotary_emb(value_states, position_ids)\n    else:\n        cos, sin = position_embeddings\n    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)\n\n    if past_key_value is not None:\n        # sin and cos are specific to RoPE models; cache_position needed for the static cache\n        cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}\n        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)\n\n    # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache\n    # to be able to avoid many of these transpose/reshape/view.\n    query_states = query_states.transpose(1, 2)\n    key_states = key_states.transpose(1, 2)\n    value_states = value_states.transpose(1, 2)\n\n    dropout_rate = self.attention_dropout if self.training else 0.0\n\n    # In PEFT, usually we cast the layer norms in float32 for training stability reasons\n    # therefore the input hidden states gets silently casted in float32. Hence, we need\n    # cast them back in the correct dtype just to be sure everything works as expected.\n    # This might slowdown training & inference so it is recommended to not cast the LayerNorms\n    # in fp32. (LlamaRMSNorm handles it correctly)\n\n    input_dtype = query_states.dtype\n    if input_dtype == torch.float32:\n        if torch.is_autocast_enabled():\n            target_dtype = torch.get_autocast_gpu_dtype()\n        # Handle the case where the model is quantized\n        elif hasattr(self.config, \"_pre_quantization_dtype\"):\n            target_dtype = self.config._pre_quantization_dtype\n        else:\n            target_dtype = self.q_proj.weight.dtype\n\n        logger.warning_once(f\"The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. 
We will cast back the input in {target_dtype}.\")\n\n        query_states = query_states.to(target_dtype)\n        key_states = key_states.to(target_dtype)\n        value_states = value_states.to(target_dtype)\n\n    attn_output = _flash_attention_forward(\n        query_states,\n        key_states,\n        value_states,\n        attention_mask,\n        full_q_len,\n        position_ids=position_ids,\n        dropout=dropout_rate,\n        sliding_window=getattr(self, \"sliding_window\", None),\n        use_top_left_mask=self._flash_attn_uses_top_left_mask,\n        is_causal=self.is_causal,\n        **kwargs,\n    )\n\n    attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous()\n    ########## AlltoAll for Ulysses ##########\n    if ulysses_sp_size > 1:\n        attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)\n    attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()\n    attn_output = self.o_proj(attn_output)\n\n    if not output_attentions:\n        attn_weights = None\n\n    return attn_output, attn_weights, past_key_value\n\n\ndef llama_attn_forward(\n    self,\n    hidden_states: torch.Tensor,\n    position_embeddings: Tuple[torch.Tensor, torch.Tensor],\n    attention_mask: Optional[torch.Tensor],\n    past_key_value: Optional[Cache] = None,\n    cache_position: Optional[torch.LongTensor] = None,\n    **kwargs,\n) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n    \"\"\"\n    Adapted from transformers 4.49.0 to support Ulysses sequence parallelism for transformers >= 4.48.0.\n\n    NOTE: This function has been tested only on transformers versions between 4.48.0 and 4.50.0.\n    \"\"\"\n    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS\n    from transformers.models.llama.modeling_llama import eager_attention_forward\n\n    bsz, q_len, _ = hidden_states.shape\n\n    query_states = self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)\n    key_states = self.k_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)\n    value_states = self.v_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)\n\n    ########## AlltoAll for Ulysses ##########\n    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()\n\n    if ulysses_sp_size > 1:\n        validate_ulysses_config(self.config.num_attention_heads, ulysses_sp_size)\n\n        query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)\n        key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)\n        value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)\n\n    full_q_len = query_states.size(2)\n\n    cos, sin = position_embeddings\n    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)\n\n    if past_key_value is not None:\n        # sin and cos are specific to RoPE models; cache_position needed for the static cache\n        cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}\n        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)\n\n    attention_interface: Callable = eager_attention_forward\n    if self.config._attn_implementation != \"eager\":\n        if self.config._attn_implementation == \"sdpa\" and kwargs.get(\"output_attentions\", False):\n            logger.warning_once('`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. 
Falling back to eager attention. This warning can be removed using the argument `attn_implementation=\"eager\"` when loading the model.')\n        else:\n            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]\n\n    attn_output, attn_weights = attention_interface(\n        self,\n        query_states,\n        key_states,\n        value_states,\n        attention_mask,\n        dropout=0.0 if not self.training else self.attention_dropout,\n        scaling=self.scaling,\n        **kwargs,\n    )\n\n    attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous()\n    ########## AlltoAll for Ulysses ##########\n    if ulysses_sp_size > 1:\n        attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)\n    attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()\n    attn_output = self.o_proj(attn_output)\n    return attn_output, attn_weights\n\n\n@dataclass\nclass CausalLMOutputForPPO(CausalLMOutputWithPast):\n    log_probs: Optional[torch.FloatTensor] = None\n    entropy: Optional[torch.FloatTensor] = None\n\n\ndef forward_for_ppo(\n    self,\n    input_ids: torch.LongTensor = None,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_values: Optional[Union[\"Cache\", List[torch.FloatTensor]]] = None,\n    inputs_embeds: Optional[torch.FloatTensor] = None,\n    labels: Optional[torch.LongTensor] = None,\n    use_cache: Optional[bool] = None,\n    output_attentions: Optional[bool] = None,\n    output_hidden_states: Optional[bool] = None,\n    return_dict: Optional[bool] = None,\n    cache_position: Optional[torch.LongTensor] = None,\n    logits_to_keep: Union[int, torch.Tensor] = 0,\n    temperature: float = 1.0,\n    **loss_kwargs,\n) -> Union[Tuple, CausalLMOutputForPPO]:\n    r\"\"\"\n    Copy paste LLaMa's forward\n    https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/llama.py\n\n    This function should be generic enough for all pure text models.\n    ```\"\"\"\n    from siirl.utils.experimental.torch_functional import FusedLinearForPPO\n\n    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n    output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n    return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n    outputs = self.model(\n        input_ids=input_ids,\n        attention_mask=attention_mask,\n        position_ids=position_ids,\n        past_key_values=past_key_values,\n        inputs_embeds=inputs_embeds,\n        use_cache=use_cache,\n        output_attentions=output_attentions,\n        output_hidden_states=output_hidden_states,\n        return_dict=return_dict,\n        cache_position=cache_position,\n    )\n\n    hidden_states = outputs[0]\n\n    if not return_dict:\n        raise NotImplementedError(\"forward_for_ppo has to return_dict\")\n\n    # Loss calculations\n    if labels is not None:\n        rolled_labels = torch.roll(labels, shifts=-1, dims=-1)\n    elif input_ids is not None:\n        rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1)\n    else:\n        raise RuntimeError(\"To use forward_for_ppo, either labels or input_ids must be provided.\")\n\n    fused_linear_for_ppo = FusedLinearForPPO()\n    log_probs, entropy = 
fused_linear_for_ppo.forward(\n        hidden_states=hidden_states,\n        vocab_weights=self.lm_head.weight,\n        input_ids=rolled_labels,\n        temperature=temperature,\n    )\n\n    return CausalLMOutputForPPO(\n        log_probs=log_probs,\n        entropy=entropy,\n        past_key_values=outputs.past_key_values,\n        hidden_states=outputs.hidden_states,\n        attentions=outputs.attentions,\n    )\n"
  },
  {
    "path": "siirl/models/transformers/monkey_patch.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nApply monkey-patch function to models\n\"\"\"\n\nimport importlib.metadata\nimport sys\nfrom functools import lru_cache\nfrom typing import Optional\n\nimport torch\nfrom packaging import version\nfrom transformers.modeling_flash_attention_utils import _flash_attention_forward\nfrom transformers.modeling_utils import PreTrainedModel\n\nfrom siirl.utils.model_utils.ulysses import (\n    gather_heads_scatter_seq,\n    gather_seq_scatter_heads,\n    get_ulysses_sequence_parallel_group,\n    get_ulysses_sequence_parallel_world_size,\n    slice_input_tensor,\n)\n\nfrom loguru import logger\n\n\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=2, repeats=n_rep). The hidden states go from (batch,\n    seqlen, num_key_value_heads, head_dim) to (batch, seqlen, num_attention_heads, head_dim)\n    \"\"\"\n    batch, slen, num_key_value_heads, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, :, None, :].expand(batch, slen, num_key_value_heads, n_rep, head_dim)\n    return hidden_states.reshape(batch, slen, num_key_value_heads * n_rep, head_dim)\n\n\ndef _ulysses_flash_attention_forward(\n    query_states: torch.Tensor,\n    key_states: torch.Tensor,\n    value_states: torch.Tensor,\n    *args,\n    position_ids: Optional[torch.Tensor] = None,\n    **kwargs,\n):\n    \"\"\"Insert all-to-all before and after flash attention.\n    DeepSpeed-Ulysses: https://arxiv.org/pdf/2309.14509\n\n    Args:\n        query_states (torch.Tensor): (batch_size, seqlen/sp_size, nheads, head_dim)\n        key_states (torch.Tensor): (batch_size, seqlen/sp_size, nheads_k, head_dim)\n        value_states (torch.Tensor): (batch_size, seqlen/sp_size, nheads_k, head_dim)\n        position_ids (torch.Tensor, optional): (batch_size, seqlen/sp_size)\n\n    Returns:\n        torch.Tensor: (batch_size, seqlen/sp_size, nheads, head_dim)\n    \"\"\"\n    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()\n\n    ########## AlltoAll for Ulysses ##########\n    if ulysses_sp_size > 1:\n        assert position_ids is not None, \"position_ids is required for Ulysses sequence parallelism\"\n\n        # NOTE: repeat kv heads to be divided by sequence parallel. 
Instead of repeating nheads_q//nheads_k,\n        # we choose to repeat sp_size//nheads_k, since flash_attention supports MQA/GQA.\n        # For example:\n        # - nheads_k=4, sp=8, repeats=2\n        # - nheads_k=8, sp=8, repeats=1\n        # - nheads_k=16, sp=8, repeats=1\n        repeats = max(ulysses_sp_size // key_states.size(2), 1)\n        key_states = repeat_kv(key_states, repeats)\n        value_states = repeat_kv(value_states, repeats)\n\n        # (bsz, seq_len/n, n_head, head_dim) -> (bsz, seq_len, n_head/n, head_dim)\n        query_states = gather_seq_scatter_heads(query_states, seq_dim=1, head_dim=2)\n        key_states = gather_seq_scatter_heads(key_states, seq_dim=1, head_dim=2)\n        value_states = gather_seq_scatter_heads(value_states, seq_dim=1, head_dim=2)\n\n        # TODO: all_gather position_ids because `prepare_fa2_from_position_ids` needs it, we can eliminate\n        # this all_gather by passing cu_seq_lens_q, cu_seq_lens_k, max_length_k, max_length_q explicitly.\n        # https://github.com/huggingface/transformers/pull/33932\n\n        # (bsz, seq_len/n) -> (bsz, seq_len)\n        position_ids_list = [torch.empty_like(position_ids) for _ in range(ulysses_sp_size)]\n        torch.distributed.all_gather(position_ids_list, position_ids, group=get_ulysses_sequence_parallel_group())\n        position_ids = torch.concat(position_ids_list, dim=-1)\n\n    # (bsz, seq_len, n_head/n, head_dim)\n    attn_output = _flash_attention_forward(query_states, key_states, value_states, *args, position_ids=position_ids, **kwargs)\n\n    ########## AlltoAll for Ulysses ##########\n    if ulysses_sp_size > 1:\n        # (bsz, seq_len, n_head/n, head_dim) -> (bsz, seq_len/n, n_head, head_dim)\n        attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)\n\n    return attn_output\n\n\ndef patch_vlm_for_ulysses_input_slicing(model_class: type):\n    \"\"\"\n    Applies a monkey patch to the forward method of a given model class\n    to enable Ulysses sequence parallelism input slicing.\n    \"\"\"\n\n    def _create_ulysses_wrapped_decoder_forward(original_forward):\n        def ulysses_wrapped_decoder_forward(self, *args, **kwargs):\n            inputs_embeds = kwargs.get(\"inputs_embeds\")\n            call_kwargs = kwargs.copy()\n\n            current_ulysses_sp_size = get_ulysses_sequence_parallel_world_size()\n\n            slice_now = inputs_embeds is not None and current_ulysses_sp_size > 1 and getattr(self, \"_needs_initial_slice\", True)\n            if slice_now:\n                call_kwargs[\"inputs_embeds\"] = slice_input_tensor(inputs_embeds, dim=1, padding=False)\n                self._needs_initial_slice = False\n            try:\n                return original_forward(self, *args, **call_kwargs)\n            finally:\n                if slice_now:\n                    self._needs_initial_slice = True\n\n        return ulysses_wrapped_decoder_forward\n\n    original_forward = model_class.forward\n    wrapped_forward = _create_ulysses_wrapped_decoder_forward(original_forward)\n    model_class.forward = wrapped_forward\n    print(f\"Monkey patch {model_class.__name__}.forward for Ulysses SP input slicing.\")\n\n\ndef apply_monkey_patch(\n    model: PreTrainedModel,\n    ulysses_sp_size: int = 1,\n    use_remove_padding: bool = True,\n    use_fused_kernels: bool = False,\n):\n    \"\"\"Replace _flash_attention_forward to _ulysses_flash_attention_forward\"\"\"\n    module = sys.modules[model.__module__]\n\n    try:\n        num_attention_heads, 
num_key_value_heads = model.config.num_attention_heads, model.config.num_key_value_heads\n    except AttributeError:\n        num_attention_heads, num_key_value_heads = model.config.text_config.num_attention_heads, model.config.text_config.num_key_value_heads\n\n    assert num_attention_heads % ulysses_sp_size == 0, f\"num_attention_heads {num_attention_heads} must be divisible by ulysses_sp_size {ulysses_sp_size}\"\n    assert num_key_value_heads % ulysses_sp_size == 0 or ulysses_sp_size % num_key_value_heads == 0, (\n        f\"num_key_value_heads {num_key_value_heads} must be divisible by ulysses_sp_size {ulysses_sp_size}or vise versa. Upon ulysses_sp_size % num_key_value_heads == 0,kv heads are repeated to ensure correctness.\"\n    )\n    # TODO: VLM models only, unify monkey patch to LLM models.\n    if model.config.model_type == \"qwen2_5_vl\":\n        if is_transformers_version_in_range(min_version=\"4.53.0\"):\n            from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLAttention\n        else:\n            from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (\n                Qwen2_5_VLFlashAttention2 as Qwen2_5_VLAttention,\n            )\n\n        if use_remove_padding or ulysses_sp_size > 1:\n            from siirl.models.transformers.qwen2_vl import ulysses_flash_attn_forward\n\n            Qwen2_5_VLAttention.forward = ulysses_flash_attn_forward\n            logger.info(\"Monkey patch FlashAttention2.forward in Qwen2.5VL\")\n\n        if ulysses_sp_size > 1:\n            if is_transformers_version_in_range(min_version=\"4.52.0\"):\n                from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLTextModel\n\n                patch_vlm_for_ulysses_input_slicing(Qwen2_5_VLTextModel)\n            else:\n                from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel\n\n                patch_vlm_for_ulysses_input_slicing(Qwen2_5_VLModel)\n\n    elif model.config.model_type == \"qwen2_vl\":\n        if is_transformers_version_in_range(min_version=\"4.53.0\"):\n            from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLAttention\n        else:\n            from transformers.models.qwen2_vl.modeling_qwen2_vl import (\n                Qwen2VLFlashAttention2 as Qwen2VLAttention,\n                Qwen2VLForConditionalGeneration,\n            )\n\n        if use_remove_padding or ulysses_sp_size > 1:\n            from siirl.models.transformers.qwen2_vl import ulysses_flash_attn_forward\n\n            Qwen2VLAttention.forward = ulysses_flash_attn_forward\n            logger.info(\"Monkey patch FlashAttention2.forward in Qwen2VL\")\n\n        if ulysses_sp_size > 1:\n            if is_transformers_version_in_range(min_version=\"4.52.0\"):\n                from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLTextModel\n\n                patch_vlm_for_ulysses_input_slicing(Qwen2VLTextModel)\n            else:\n                from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel\n\n                patch_vlm_for_ulysses_input_slicing(Qwen2VLModel)\n\n        if use_fused_kernels:\n            from siirl.models.transformers.qwen2_vl import forward_for_ppo\n\n            Qwen2VLForConditionalGeneration.forward = forward_for_ppo\n\n        return\n\n    elif model.config.model_type == \"kimi_vl\":\n        if use_remove_padding or ulysses_sp_size > 1:\n            # TODO: Changes need to be made when transformers are adapted.\n            from 
siirl.models.transformers.kimi_vl import _ulysses_flash_attn_forward\n\n            module.DeepseekV3FlashAttention2.forward = _ulysses_flash_attn_forward\n            logger.info(\"Monkey patch FlashAttention2.forward in KimiVL\")\n\n        if ulysses_sp_size > 1:\n            patch_vlm_for_ulysses_input_slicing(module.DeepseekV3ForCausalLM)\n\n        if use_fused_kernels:\n            logger.warning(\"Not support fused kernels for KimiVL\")\n\n        return\n\n    # transformers<=4.47.1\n    if use_remove_padding or ulysses_sp_size > 1:\n        if hasattr(module, \"_flash_attention_forward\"):\n            module._flash_attention_forward = _ulysses_flash_attention_forward\n            logger.info(f\"Monkey patch _flash_attention_forward in {model.__module__}\")\n        else:\n            # transformers>=4.48.0\n            from transformers.integrations import flash_attention\n\n            flash_attention._flash_attention_forward = _ulysses_flash_attention_forward\n            logger.info(f\"Monkey patch _flash_attention_forward in {flash_attention.__name__}\")\n\n\n\n@lru_cache\ndef is_transformers_version_in_range(min_version: Optional[str] = None, max_version: Optional[str] = None) -> bool:\n    try:\n        # Get the installed version of the transformers library\n        transformers_version_str = importlib.metadata.version(\"transformers\")\n    except importlib.metadata.PackageNotFoundError as e:\n        raise ModuleNotFoundError(\"The `transformers` package is not installed.\") from e\n\n    transformers_version = version.parse(transformers_version_str)\n\n    lower_bound_check = True\n    if min_version is not None:\n        lower_bound_check = version.parse(min_version) <= transformers_version\n\n    upper_bound_check = True\n    if max_version is not None:\n        upper_bound_check = transformers_version <= version.parse(max_version)\n\n    return lower_bound_check and upper_bound_check\n"
  },
  {
    "path": "siirl/models/transformers/npu_patch.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Copyright 2025 The Qwen Team and The HuggingFace Inc. team\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Tuple\n\nimport torch\nimport torch_npu\nfrom torch_npu import npu_rotary_mul as apply_rotary_emb\nfrom transformers.models.qwen2_5_vl import modeling_qwen2_5_vl\nfrom transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2RMSNorm\nfrom transformers.models.qwen2 import modeling_qwen2\n\n\n# This patch takes effect when using apply_rotary_pos_emb_flashatt on qwen2_5_vl and will be removed in subsequent versions\n# https://github.com/huggingface/transformers/pull/38491\ndef apply_rotary_pos_emb_flashatt_npu(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n    cos = cos.chunk(2, dim=-1)[0].contiguous()\n    sin = sin.chunk(2, dim=-1)[0].contiguous()\n    cos = cos.repeat(1, 2)\n    sin = sin.repeat(1, 2)\n    q_embed = apply_rotary_emb(q.float(), cos.unsqueeze(0).unsqueeze(2).float(), sin.unsqueeze(0).unsqueeze(2).float()).type_as(q)\n    k_embed = apply_rotary_emb(k.float(), cos.unsqueeze(0).unsqueeze(2).float(), sin.unsqueeze(0).unsqueeze(2).float()).type_as(k)\n    return q_embed, k_embed\n\n\ndef apply_rotary_pos_emb_npu(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):\n    cos = cos.unsqueeze(unsqueeze_dim)\n    sin = sin.unsqueeze(unsqueeze_dim)\n    q_embed = torch_npu.npu_rotary_mul(q, cos, sin)\n    k_embed = torch_npu.npu_rotary_mul(k, cos, sin)\n    return q_embed, k_embed\n\n\n# This api can improve performance on ASCEND NPU\ndef rms_norm_forward(self, x):\n    return torch_npu.npu_rms_norm(x, self.weight, epsilon=self.variance_epsilon)[0]\n\n\nQwen2RMSNorm.forward = rms_norm_forward\nmodeling_qwen2_5_vl.apply_rotary_pos_emb_flashatt = apply_rotary_pos_emb_flashatt_npu\nmodeling_qwen2.Qwen2RMSNorm.forward = rms_norm_forward\nmodeling_qwen2.apply_rotary_pos_emb = apply_rotary_pos_emb_npu\n"
  },
  {
    "path": "siirl/models/transformers/qwen2.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Callable, Optional, Tuple\n\nimport torch\nfrom transformers.cache_utils import Cache\nfrom transformers.modeling_flash_attention_utils import _flash_attention_forward\nfrom transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv\nfrom transformers.utils import logging\n\nfrom siirl.utils.model_utils.ulysses import (\n    gather_heads_scatter_seq,\n    gather_seq_scatter_heads,\n    get_ulysses_sequence_parallel_world_size,\n    validate_ulysses_config,\n)\n\nlogger = logging.get_logger(__name__)\n\n\ndef qwen2_flash_attn_forward(\n    self,\n    hidden_states: torch.Tensor,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_value: Optional[Cache] = None,\n    output_attentions: bool = False,\n    use_cache: bool = False,\n    cache_position: Optional[torch.LongTensor] = None,\n    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46\n):\n    \"\"\"\n    Adapted from transformers 4.47.1 to support Ulysses sequence parallelism.\n\n    NOTE: This function is only tested on transformers versions between 4.45.0 and 4.47.1.\n    \"\"\"\n    bsz, q_len, _ = hidden_states.size()\n\n    query_states = self.q_proj(hidden_states)\n    key_states = self.k_proj(hidden_states)\n    value_states = self.v_proj(hidden_states)\n\n    query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)\n    key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)\n    value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)\n\n    ########## AlltoAll for Ulysses ##########\n    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()\n\n    if ulysses_sp_size > 1:\n        validate_ulysses_config(self.num_heads, ulysses_sp_size)\n\n        # (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim)\n        query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)\n        key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)\n        value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)\n\n    full_q_len = query_states.size(2)  # full seq length\n\n    if position_embeddings is None:\n        logger.warning_once(\n            \"The attention layers in this model are transitioning from computing the RoPE embeddings internally \"\n            \"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed \"\n            \"`position_embeddings` (Tuple of tensors, containing cos and sin). 
In v4.46 `position_ids` will be \"\n            \"removed and `position_embeddings` will be mandatory.\"\n        )\n        cos, sin = self.rotary_emb(value_states, position_ids)\n    else:\n        cos, sin = position_embeddings\n    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)\n\n    if past_key_value is not None:\n        cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}  # Specific to RoPE models\n        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)\n\n    # repeat k/v heads if n_kv_heads < n_heads\n    key_states = repeat_kv(key_states, self.num_key_value_groups)\n    value_states = repeat_kv(value_states, self.num_key_value_groups)\n    dropout_rate = 0.0 if not self.training else self.attention_dropout\n\n    # In PEFT, usually we cast the layer norms in float32 for training stability reasons\n    # therefore the input hidden states gets silently casted in float32. Hence, we need\n    # cast them back in float16 just to be sure everything works as expected.\n    input_dtype = query_states.dtype\n    if input_dtype == torch.float32:\n        if torch.is_autocast_enabled():\n            target_dtype = torch.get_autocast_gpu_dtype()\n        # Handle the case where the model is quantized\n        elif hasattr(self.config, \"_pre_quantization_dtype\"):\n            target_dtype = self.config._pre_quantization_dtype\n        else:\n            target_dtype = self.q_proj.weight.dtype\n\n        logger.warning_once(f\"The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in {target_dtype}.\")\n\n        query_states = query_states.to(target_dtype)\n        key_states = key_states.to(target_dtype)\n        value_states = value_states.to(target_dtype)\n\n    # Reashape to the expected shape for Flash Attention\n    query_states = query_states.transpose(1, 2)\n    key_states = key_states.transpose(1, 2)\n    value_states = value_states.transpose(1, 2)\n\n    if self.config.use_sliding_window and getattr(self.config, \"sliding_window\", None) is not None and self.layer_idx >= self.config.max_window_layers:\n        sliding_window = self.config.sliding_window\n    else:\n        sliding_window = None\n\n    attn_output = _flash_attention_forward(\n        query_states,\n        key_states,\n        value_states,\n        attention_mask,\n        full_q_len,\n        position_ids=position_ids,\n        dropout=dropout_rate,\n        sliding_window=sliding_window,\n        is_causal=self.is_causal,\n        use_top_left_mask=self._flash_attn_uses_top_left_mask,\n    )\n\n    # use full_q_len to reshape\n    attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous()\n    ########## AlltoAll for Ulysses ##########\n    if ulysses_sp_size > 1:\n        attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)\n    attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()\n    attn_output = self.o_proj(attn_output)\n\n    if not output_attentions:\n        attn_weights = None\n\n    return attn_output, attn_weights, past_key_value\n\n\ndef qwen2_attn_forward(\n    self,\n    hidden_states: torch.Tensor,\n    position_embeddings: Tuple[torch.Tensor, torch.Tensor],\n    attention_mask: Optional[torch.Tensor],\n    past_key_value: Optional[Cache] = None,\n    cache_position: 
Optional[torch.LongTensor] = None,\n    **kwargs,\n) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n    \"\"\"\n    Adapted from transformers 4.49.0 to support Ulysses sequence parallelism for transformers >= 4.48.0.\n\n    NOTE: This function has been tested only on transformers versions between 4.48.0 and 4.50.0.\n    \"\"\"\n    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS\n\n    bsz, q_len, _ = hidden_states.shape\n    hidden_shape = (bsz, q_len, -1, self.head_dim)\n\n    query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n    key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n\n    ########## AlltoAll for Ulysses ##########\n    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()\n\n    if ulysses_sp_size > 1:\n        validate_ulysses_config(self.config.num_attention_heads, ulysses_sp_size)\n\n        # (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim)\n        query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)\n        key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)\n        value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)\n\n    full_q_len = query_states.size(2)\n\n    cos, sin = position_embeddings\n    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)\n\n    if past_key_value is not None:\n        # sin and cos are specific to RoPE models; cache_position needed for the static cache\n        cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}\n        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)\n\n    sliding_window = None\n    if self.config.use_sliding_window and getattr(self.config, \"sliding_window\", None) is not None and self.layer_idx >= self.config.max_window_layers:\n        sliding_window = self.config.sliding_window\n\n    from transformers.models.qwen2.modeling_qwen2 import eager_attention_forward\n\n    attention_interface: Callable = eager_attention_forward\n    if self.config._attn_implementation != \"eager\":\n        if self.config._attn_implementation == \"sdpa\" and kwargs.get(\"output_attentions\", False):\n            logger.warning_once('`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to eager attention. 
This warning can be removed using the argument `attn_implementation=\"eager\"` when loading the model.')\n        else:\n            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]\n\n    attn_output, attn_weights = attention_interface(\n        self,\n        query_states,\n        key_states,\n        value_states,\n        attention_mask,\n        dropout=0.0 if not self.training else self.attention_dropout,\n        scaling=self.scaling,\n        sliding_window=sliding_window,  # main diff with Llama\n        **kwargs,\n    )\n\n    attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous()\n    ########## AlltoAll for Ulysses ##########\n    if ulysses_sp_size > 1:\n        # (bsz, seq_len, n_head/n, head_dim) -> (bsz, seq_len/n, n_head, head_dim)\n        attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)\n    attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()\n    attn_output = self.o_proj(attn_output)\n    return attn_output, attn_weights\n"
  },
  {
    "path": "siirl/models/transformers/qwen2_5_vl.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom dataclasses import dataclass\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nfrom transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (\n    Qwen2_5_VLCausalLMOutputWithPast,\n    Qwen2_5_VLForConditionalGeneration,\n)\n\n\n@dataclass\nclass Qwen2_5_VLCausalLMOutputForPPO(Qwen2_5_VLCausalLMOutputWithPast):\n    log_probs: Optional[torch.FloatTensor] = None\n    entropy: Optional[torch.FloatTensor] = None\n\n\ndef forward_for_ppo(\n    self: Qwen2_5_VLForConditionalGeneration,\n    input_ids: torch.LongTensor = None,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_values: Optional[List[torch.FloatTensor]] = None,\n    inputs_embeds: Optional[torch.FloatTensor] = None,\n    labels: Optional[torch.LongTensor] = None,\n    use_cache: Optional[bool] = None,\n    output_attentions: Optional[bool] = None,\n    output_hidden_states: Optional[bool] = None,\n    return_dict: Optional[bool] = None,\n    pixel_values: Optional[torch.Tensor] = None,\n    pixel_values_videos: Optional[torch.FloatTensor] = None,\n    image_grid_thw: Optional[torch.LongTensor] = None,\n    video_grid_thw: Optional[torch.LongTensor] = None,\n    rope_deltas: Optional[torch.LongTensor] = None,\n    cache_position: Optional[torch.LongTensor] = None,\n    second_per_grid_ts: Optional[torch.Tensor] = None,\n    temperature: float = 1.0,\n    **loss_kwargs,\n) -> Union[Tuple, Qwen2_5_VLCausalLMOutputForPPO]:\n    r\"\"\"\n    Copy paste Qwen2_5_VL's forward\n    https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/qwen2_5_vl.py\n    ```\"\"\"\n    from siirl.utils.experimental.torch_functional import FusedLinearForPPO\n\n    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n    output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n    return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n    if inputs_embeds is None:\n        inputs_embeds = self.model.embed_tokens(input_ids)\n        if pixel_values is not None:\n            pixel_values = pixel_values.type(self.visual.dtype)\n            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)\n            n_image_tokens = (input_ids == self.config.image_token_id).sum().item()\n            n_image_features = image_embeds.shape[0]\n            if n_image_tokens != n_image_features:\n                raise ValueError(f\"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}\")\n\n            mask = input_ids == self.config.image_token_id\n            mask_unsqueezed = mask.unsqueeze(-1)\n            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)\n            image_mask = 
mask_expanded.to(inputs_embeds.device)\n\n            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)\n            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)\n\n        if pixel_values_videos is not None:\n            pixel_values_videos = pixel_values_videos.type(self.visual.dtype)\n            video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)\n            n_video_tokens = (input_ids == self.config.video_token_id).sum().item()\n            n_video_features = video_embeds.shape[0]\n            if n_video_tokens != n_video_features:\n                raise ValueError(f\"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}\")\n\n            mask = input_ids == self.config.video_token_id\n            mask_unsqueezed = mask.unsqueeze(-1)\n            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)\n            video_mask = mask_expanded.to(inputs_embeds.device)\n\n            video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)\n            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)\n\n        if attention_mask is not None:\n            attention_mask = attention_mask.to(inputs_embeds.device)\n\n    # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme\n    if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):\n        # calculate RoPE index once per generation in the pre-fill stage only\n        if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:\n            position_ids, rope_deltas = self.get_rope_index(\n                input_ids,\n                image_grid_thw,\n                video_grid_thw,\n                second_per_grid_ts,\n                attention_mask,\n            )\n            self.rope_deltas = rope_deltas\n        # then use the prev pre-calculated rope-deltas to get the correct position ids\n        else:\n            batch_size, seq_length, _ = inputs_embeds.shape\n            delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) if cache_position is not None else 0\n            position_ids = torch.arange(seq_length, device=inputs_embeds.device)\n            position_ids = position_ids.view(1, -1).expand(batch_size, -1)\n            if cache_position is not None:  # otherwise `deltas` is an int `0`\n                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)\n            position_ids = position_ids.add(delta)\n            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)\n\n    outputs = self.model(\n        input_ids=None,\n        position_ids=position_ids,\n        attention_mask=attention_mask,\n        past_key_values=past_key_values,\n        inputs_embeds=inputs_embeds,\n        use_cache=use_cache,\n        output_attentions=output_attentions,\n        output_hidden_states=output_hidden_states,\n        return_dict=return_dict,\n        cache_position=cache_position,\n    )\n\n    hidden_states = outputs[0]\n\n    if not return_dict:\n        raise NotImplementedError(\"forward_for_ppo has to return_dict\")\n\n    # Loss calculations\n    if labels is not None:\n        rolled_labels = torch.roll(labels, shifts=-1, dims=-1)\n    elif input_ids is not None:\n        rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1)\n    else:\n        raise RuntimeError(\"To use forward_for_ppo, either labels or input_ids must be 
provided.\")\n\n    fused_linear_for_ppo = FusedLinearForPPO()\n    log_probs, entropy = fused_linear_for_ppo.forward(\n        hidden_states=hidden_states,\n        vocab_weights=self.lm_head.weight,\n        input_ids=rolled_labels,\n        temperature=temperature,\n    )\n\n    return Qwen2_5_VLCausalLMOutputForPPO(\n        log_probs=log_probs,\n        entropy=entropy,\n        past_key_values=outputs.past_key_values,\n        hidden_states=outputs.hidden_states,\n        attentions=outputs.attentions,\n        rope_deltas=rope_deltas,\n    )\n"
  },
  {
    "path": "siirl/models/transformers/qwen2_vl.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nimport os\nfrom dataclasses import dataclass\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nfrom transformers.modeling_flash_attention_utils import _flash_attention_forward\nfrom transformers.models.qwen2_vl.modeling_qwen2_vl import (\n    Qwen2VLCausalLMOutputWithPast,\n    Qwen2VLForConditionalGeneration,\n)\nfrom transformers.utils import is_flash_attn_greater_or_equal\n\nfrom siirl.utils.model_utils.ulysses import (\n    gather_heads_scatter_seq,\n    gather_seq_scatter_heads,\n    get_ulysses_sequence_parallel_world_size,\n    validate_ulysses_config,\n)\n\nfrom siirl.models.transformers.monkey_patch import is_transformers_version_in_range\n\n# Handle version compatibility for flash_attn_supports_top_left_mask\nfrom siirl.models.transformers.transformers_compat import flash_attn_supports_top_left_mask\n\ntry:\n    from transformers.modeling_flash_attention_utils import flash_attn_func, flash_attn_varlen_func\n\n    _flash_supports_window_size = \"window_size\" in list(inspect.signature(flash_attn_func).parameters)\nexcept ImportError:\n    try:\n        from transformers.modeling_flash_attention_utils import _lazy_imports\n        flash_attn_func, flash_attn_varlen_func, *_ = _lazy_imports(None)\n    except ImportError or ValueError:\n        from flash_attn import flash_attn_func, flash_attn_varlen_func\n\n    _flash_supports_window_size = None\n\n\ndef get_rope_index(\n    processor,\n    input_ids: torch.Tensor,\n    image_grid_thw: Optional[torch.Tensor] = None,\n    video_grid_thw: Optional[torch.Tensor] = None,\n    second_per_grid_ts: Optional[torch.Tensor] = None,\n    attention_mask: Optional[torch.Tensor] = None,\n) -> torch.Tensor:\n    \"\"\"\n    Gets the position ids for Qwen2-VL, it should be generated before sharding the sequence.\n    The batch dim has been removed and the input_ids should be a 1D tensor representing a single example.\n    https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1546\n    \"\"\"\n    spatial_merge_size = processor.image_processor.merge_size\n    tokens_per_second = 2\n    image_token_id = processor.tokenizer.convert_tokens_to_ids(\"<|image_pad|>\")\n    video_token_id = processor.tokenizer.convert_tokens_to_ids(\"<|video_pad|>\")\n    vision_start_token_id = processor.tokenizer.convert_tokens_to_ids(\"<|vision_start|>\")\n    if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):\n        if attention_mask is None:\n            attention_mask = torch.ones_like(input_ids)\n\n        position_ids = torch.ones(3, input_ids.size(0), dtype=input_ids.dtype, device=input_ids.device)  # (3, seqlen)\n        image_index, video_index = 0, 0\n        input_ids = input_ids[attention_mask == 1]\n        image_nums, video_nums = 0, 0\n        vision_start_indices = torch.argwhere(input_ids == 
vision_start_token_id)\n        vision_tokens = input_ids[vision_start_indices + 1]\n        image_nums = (vision_tokens == image_token_id).sum()\n        video_nums = (vision_tokens == video_token_id).sum()\n        input_tokens = input_ids.tolist()\n        llm_pos_ids_list: list = []\n        st = 0\n        remain_images, remain_videos = image_nums, video_nums\n        for _ in range(image_nums + video_nums):\n            if image_token_id in input_tokens and remain_images > 0:\n                ed_image = input_tokens.index(image_token_id, st)\n            else:\n                ed_image = len(input_tokens) + 1\n            if video_token_id in input_tokens and remain_videos > 0:\n                ed_video = input_tokens.index(video_token_id, st)\n            else:\n                ed_video = len(input_tokens) + 1\n            if ed_image < ed_video:\n                t, h, w = (\n                    image_grid_thw[image_index][0],\n                    image_grid_thw[image_index][1],\n                    image_grid_thw[image_index][2],\n                )\n                second_per_grid_t = 0\n                image_index += 1\n                remain_images -= 1\n                ed = ed_image\n            else:\n                t, h, w = (\n                    video_grid_thw[video_index][0],\n                    video_grid_thw[video_index][1],\n                    video_grid_thw[video_index][2],\n                )\n                second_per_grid_t = second_per_grid_ts[video_index] if second_per_grid_ts is not None else 1.0\n\n                video_index += 1\n                remain_videos -= 1\n                ed = ed_video\n\n            llm_grid_t, llm_grid_h, llm_grid_w = (\n                t.item(),\n                h.item() // spatial_merge_size,\n                w.item() // spatial_merge_size,\n            )\n            text_len = ed - st\n\n            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0\n            llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)\n\n            t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)\n            t_index = (t_index * second_per_grid_t * tokens_per_second).long().flatten()\n            h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()\n            w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()\n            llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)\n            st = ed + llm_grid_t * llm_grid_h * llm_grid_w\n\n        if st < len(input_tokens):\n            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0\n            text_len = len(input_tokens) - st\n            llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)\n\n        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)\n        position_ids[..., attention_mask == 1] = llm_positions.to(position_ids.device)\n    else:\n        if attention_mask is not None:\n            position_ids = attention_mask.long().cumsum(-1) - 1\n            position_ids.masked_fill_(attention_mask == 0, 1)\n            position_ids = position_ids.unsqueeze(0).expand(3, -1).to(input_ids.device)\n        else:\n            position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).view(1, -1).expand(3, -1)\n\n    return position_ids\n\n\ndef prepare_fa2_from_position_ids(query: torch.Tensor, key: 
torch.Tensor, value: torch.Tensor, position_ids: torch.Tensor):\n    query = query.view(-1, query.size(-2), query.size(-1))\n    key = key.view(-1, key.size(-2), key.size(-1))\n    value = value.view(-1, value.size(-2), value.size(-1))\n    position_ids = position_ids.flatten()\n    indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)\n    cu_seqlens = torch.cat(\n        (\n            indices_q[position_ids == 0],\n            torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),\n        )\n    )\n    max_length = cu_seqlens.diff().max()  # use cu_seqlens to infer max_length for qwen2vl mrope\n    return (query, key, value, indices_q, (cu_seqlens, cu_seqlens), (max_length, max_length))\n\n\ndef flash_attention_forward(\n    query_states: torch.Tensor,\n    key_states: torch.Tensor,\n    value_states: torch.Tensor,\n    attention_mask: torch.Tensor,\n    query_length: int,\n    is_causal: bool = True,\n    position_ids: Optional[torch.Tensor] = None,\n    sliding_window: Optional[int] = None,\n    use_top_left_mask: bool = False,\n    deterministic: Optional[bool] = None,\n    **kwargs,\n):\n    \"\"\"\n    Patches flash attention forward to handle 3D position ids in mrope. (3, batch_size, seq_length)\n    \"\"\"\n    causal = is_causal if not use_top_left_mask else is_causal and query_length != 1\n\n    # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).\n    use_sliding_windows = _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window\n    flash_kwargs = {\"window_size\": (sliding_window, sliding_window)} if use_sliding_windows else {}\n\n    if is_flash_attn_greater_or_equal(\"2.4.1\"):\n        if deterministic is None:\n            deterministic = os.environ.get(\"FLASH_ATTENTION_DETERMINISTIC\", \"0\") == \"1\"\n        flash_kwargs[\"deterministic\"] = deterministic\n\n    if position_ids is not None and query_length != 1 and not (torch.diff(position_ids[0], dim=-1) >= 0).all():\n        batch_size = query_states.size(0)\n        query_states, key_states, value_states, _, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(query_states, key_states, value_states, position_ids[0])  # remove channel dimension\n        cu_seqlens_q, cu_seqlens_k = cu_seq_lens\n        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens\n        attn_output = flash_attn_varlen_func(\n            query_states,\n            key_states,\n            value_states,\n            cu_seqlens_q=cu_seqlens_q,\n            cu_seqlens_k=cu_seqlens_k,\n            max_seqlen_q=max_seqlen_in_batch_q,\n            max_seqlen_k=max_seqlen_in_batch_k,\n            dropout_p=kwargs.pop(\"dropout\", 0.0),\n            softmax_scale=kwargs.pop(\"softmax_scale\", None),\n            causal=causal,\n            **flash_kwargs,\n        )\n        attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))\n    else:\n        attn_output = _flash_attention_forward(\n            query_states,\n            key_states,\n            value_states,\n            attention_mask,\n            query_length,\n            is_causal=is_causal,\n            sliding_window=sliding_window,\n            use_top_left_mask=flash_attn_supports_top_left_mask(),\n            deterministic=deterministic,\n            **kwargs,\n        )  # do not pass position_ids to old flash_attention_forward\n\n    return attn_output\n\n\ndef 
ulysses_flash_attn_forward(\n    self,\n    hidden_states: torch.Tensor,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46\n    **kwargs,\n) -> Tuple[torch.Tensor, None, None]:\n    from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_multimodal_rotary_pos_emb, repeat_kv\n\n    bsz, q_len, _ = hidden_states.size()  # q_len = seq_length / sp_size\n    query_states = self.q_proj(hidden_states)  # (batch_size, seq_length / sp_size, num_heads * head_size)\n    key_states = self.k_proj(hidden_states)\n    value_states = self.v_proj(hidden_states)\n\n    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n\n    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()\n\n    if ulysses_sp_size > 1:\n        validate_ulysses_config(self.num_heads, ulysses_sp_size)\n\n        key_states = repeat_kv(key_states, self.num_key_value_groups)\n        value_states = repeat_kv(value_states, self.num_key_value_groups)\n        query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)\n        key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)\n        value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)\n        # (batch_size, num_head / sp_size, seq_length, head_size)\n        full_q_len = query_states.size(2)  # full_q_len = seq_length\n    else:\n        full_q_len = q_len\n\n    # Because the input can be padded, the absolute sequence length depends on the max position id.\n    if position_embeddings is None:\n        cos, sin = self.rotary_emb(value_states, position_ids)\n    else:\n        cos, sin = position_embeddings\n\n    query_states, key_states = apply_multimodal_rotary_pos_emb(query_states, key_states, cos, sin, self.rope_scaling[\"mrope_section\"])\n    dropout_rate = 0.0 if not self.training else self.attention_dropout\n\n    # Reashape to the expected shape for Flash Attention\n    query_states = query_states.transpose(1, 2)\n    key_states = key_states.transpose(1, 2)\n    value_states = value_states.transpose(1, 2)\n\n    if self.config.use_sliding_window and getattr(self.config, \"sliding_window\", None) is not None and self.layer_idx >= self.config.max_window_layers:\n        sliding_window = self.config.sliding_window\n    else:\n        sliding_window = None\n\n    attn_output = flash_attention_forward(\n        query_states,\n        key_states,\n        value_states,\n        attention_mask,\n        full_q_len,\n        dropout=dropout_rate,\n        sliding_window=sliding_window,\n        is_causal=self.is_causal,\n        use_top_left_mask=flash_attn_supports_top_left_mask(),\n        position_ids=position_ids,  # important: pass position ids\n    )  # (batch_size, seq_length, num_head / sp_size, head_size)\n    if ulysses_sp_size > 1:\n        attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1)\n\n    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()\n    attn_output = self.o_proj(attn_output)\n    if is_transformers_version_in_range(min_version=\"4.53.0\"):\n        return attn_output, None\n    else:\n        return attn_output, None, 
None\n\n\n@dataclass\nclass Qwen2VLCausalLMOutputForPPO(Qwen2VLCausalLMOutputWithPast):\n    log_probs: Optional[torch.FloatTensor] = None\n    entropy: Optional[torch.FloatTensor] = None\n\n\ndef forward_for_ppo(\n    self: Qwen2VLForConditionalGeneration,\n    input_ids: torch.LongTensor = None,\n    attention_mask: Optional[torch.Tensor] = None,\n    position_ids: Optional[torch.LongTensor] = None,\n    past_key_values: Optional[List[torch.FloatTensor]] = None,\n    inputs_embeds: Optional[torch.FloatTensor] = None,\n    labels: Optional[torch.LongTensor] = None,\n    use_cache: Optional[bool] = None,\n    output_attentions: Optional[bool] = None,\n    output_hidden_states: Optional[bool] = None,\n    return_dict: Optional[bool] = None,\n    pixel_values: Optional[torch.Tensor] = None,\n    pixel_values_videos: Optional[torch.FloatTensor] = None,\n    image_grid_thw: Optional[torch.LongTensor] = None,\n    video_grid_thw: Optional[torch.LongTensor] = None,\n    rope_deltas: Optional[torch.LongTensor] = None,\n    cache_position: Optional[torch.LongTensor] = None,\n    temperature: float = 1.0,\n    **loss_kwargs,\n) -> Union[Tuple, Qwen2VLCausalLMOutputForPPO]:\n    r\"\"\"\n    Copy paste Qwen2VL's forward\n    https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/qwen2_vl.py\n    ```\"\"\"\n    from siirl.utils.experimental.torch_functional import FusedLinearForPPO\n\n    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n    output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n    return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n    if inputs_embeds is None:\n        inputs_embeds = self.model.embed_tokens(input_ids)\n        if pixel_values is not None:\n            pixel_values = pixel_values.type(self.visual.get_dtype())\n            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)\n            n_image_tokens = (input_ids == self.config.image_token_id).sum().item()\n            n_image_features = image_embeds.shape[0]\n            if n_image_tokens != n_image_features:\n                raise ValueError(f\"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}\")\n            image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)\n            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)\n            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)\n\n        if pixel_values_videos is not None:\n            pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())\n            video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)\n            n_video_tokens = (input_ids == self.config.video_token_id).sum().item()\n            n_video_features = video_embeds.shape[0]\n            if n_video_tokens != n_video_features:\n                raise ValueError(f\"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}\")\n            video_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)\n            video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)\n            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)\n\n        if attention_mask is 
not None:\n            attention_mask = attention_mask.to(inputs_embeds.device)\n\n    if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):\n        # calculate RoPE index once per generation in the pre-fill stage only\n        if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:\n            position_ids, rope_deltas = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask)\n            self.rope_deltas = rope_deltas\n        # then use the prev pre-calculated rope-deltas to get the correct position ids\n        else:\n            batch_size, seq_length, _ = inputs_embeds.shape\n            delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0\n            position_ids = torch.arange(seq_length, device=inputs_embeds.device)\n            position_ids = position_ids.view(1, -1).expand(batch_size, -1)\n            if cache_position is not None:  # otherwise `deltas` is an int `0`\n                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)\n            position_ids = position_ids.add(delta)\n            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)\n\n    outputs = self.model(\n        input_ids=None,\n        position_ids=position_ids,\n        attention_mask=attention_mask,\n        past_key_values=past_key_values,\n        inputs_embeds=inputs_embeds,\n        use_cache=use_cache,\n        output_attentions=output_attentions,\n        output_hidden_states=output_hidden_states,\n        return_dict=return_dict,\n        cache_position=cache_position,\n    )\n\n    hidden_states = outputs[0]\n\n    if not return_dict:\n        raise NotImplementedError(\"forward_for_ppo has to return_dict\")\n\n    # Loss calculations\n    if labels is not None:\n        rolled_labels = torch.roll(labels, shifts=-1, dims=-1)\n    elif input_ids is not None:\n        rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1)\n    else:\n        raise RuntimeError(\"To use forward_for_ppo, either labels or input_ids must be provided.\")\n\n    fused_linear_for_ppo = FusedLinearForPPO()\n    log_probs, entropy = fused_linear_for_ppo.forward(\n        hidden_states=hidden_states,\n        vocab_weights=self.lm_head.weight,\n        input_ids=rolled_labels,\n        temperature=temperature,\n    )\n\n    return Qwen2VLCausalLMOutputForPPO(\n        log_probs=log_probs,\n        entropy=entropy,\n        past_key_values=outputs.past_key_values,\n        hidden_states=outputs.hidden_states,\n        attentions=outputs.attentions,\n        rope_deltas=rope_deltas,\n    )\n"
  },
  {
    "path": "siirl/models/transformers/transformers_compat.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nCompatibility utilities for different versions of transformers library.\n\"\"\"\n\n# Handle version compatibility for flash_attn_supports_top_left_mask\n# This function was added in newer versions of transformers\ntry:\n    from transformers.modeling_flash_attention_utils import flash_attn_supports_top_left_mask\nexcept ImportError:\n    # For older versions of transformers that don't have this function\n    # Default to False as a safe fallback for older versions\n    def flash_attn_supports_top_left_mask():\n        \"\"\"Fallback implementation for older transformers versions.\n        Returns False to disable features that require this function.\n        \"\"\"\n        return False\n"
  },
  {
    "path": "siirl/models/weight_loader_registry.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\ndef get_weight_loader(arch: str):\n    from siirl.models.mcore.loader import load_state_dict_to_megatron_gptmodel\n\n    _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY = {\n        \"LlamaForCausalLM\": load_state_dict_to_megatron_gptmodel,\n        \"Qwen2ForCausalLM\": load_state_dict_to_megatron_gptmodel,\n    }\n\n    if arch in _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY:\n        return _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY[arch]\n    raise ValueError(f\"Model architectures {arch} loader are not supported for now. Supported architectures: {_MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY.keys()}\")\n\n\ndef get_weight_saver(arch: str):\n    from siirl.models.mcore.saver import merge_megatron_ckpt_gptmodel, merge_megatron_ckpt_gptmodel_dpskv3, merge_megatron_ckpt_gptmodel_mixtral, merge_megatron_ckpt_gptmodel_qwen_moe\n\n    _MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY = {\n        \"LlamaForCausalLM\": merge_megatron_ckpt_gptmodel,\n        \"Qwen2ForCausalLM\": merge_megatron_ckpt_gptmodel,\n        \"MixtralForCausalLM\": merge_megatron_ckpt_gptmodel_mixtral,\n        \"Qwen2MoeForCausalLM\": merge_megatron_ckpt_gptmodel_qwen_moe,\n        \"DeepseekV3ForCausalLM\": merge_megatron_ckpt_gptmodel_dpskv3,\n        \"Qwen3ForCausalLM\": merge_megatron_ckpt_gptmodel,\n        \"Qwen3MoeForCausalLM\": merge_megatron_ckpt_gptmodel_qwen_moe,\n    }\n    if arch in _MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY:\n        return _MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY[arch]\n    raise ValueError(f\"Model architectures {arch} saver are not supported for now. Supported architectures: {_MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY.keys()}\")\n"
  },
  {
    "path": "siirl/params/__init__.py",
    "content": "from .data_args import DataArguments\nfrom .model_args import (\n    ModelArguments,\n    ActorRolloutRefArguments,\n    CriticArguments,\n    RewardModelArguments,\n    AlgorithmArguments,\n    ActorArguments,\n    RolloutArguments,\n    RefArguments,\n)\nfrom .training_args import TrainingArguments, SiiRLArguments\nfrom .parser import parse_config\nfrom .display_dict import log_dict_formatted\nfrom .profiler_args import ProfilerArguments\n\n__all__ = [\n    \"ActorRolloutRefArguments\",\n    \"CriticArguments\",\n    \"RewardModelArguments\",\n    \"AlgorithmArguments\",\n    \"DataArguments\",\n    \"ModelArguments\",\n    \"TrainingArguments\",\n    \"SiiRLArguments\",\n    \"ActorArguments\",\n    \"RefArguments\",\n    \"RolloutArguments\",\n    \"ProfilerArguments\",\n    \"parse_config\",\n    \"log_dict_formatted\",\n]\n"
  },
  {
    "path": "siirl/params/dag_args.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom dataclasses import asdict, dataclass, field\nfrom typing import Optional\n\n\n@dataclass\nclass DagArguments:\n    workflow_path: Optional[str] = field(\n        default=None,\n        metadata={\"help\": \"Workflow DAG config file (YAML, legacy mode). Consider using custom_pipeline_fn instead.\"}\n    )\n    custom_pipeline_fn: Optional[str] = field(\n        default=None,\n        metadata={\n            \"help\": \"Custom pipeline function path in format 'module.path:function_name'. \"\n                    \"Example: 'examples.custom_pipeline_example.custom_grpo:grpo_with_custom_reward'. \"\n                    \"If not specified, built-in pipelines will be used based on algorithm type.\"\n        }\n    )\n    env_enable: bool = field(default=False, metadata={\"help\": \"Enable environment\"})\n    environment_path: Optional[str] = field(default=None, metadata={\"help\": \"Environment config file\"})\n    enable_perf: bool = field(default=False, metadata={\"help\": \"Enable all ranks performance profiling table\"})\n    backend_threshold: int = field(default=256, metadata={\"help\": \"World size threshold for backend selection\"})\n"
  },
  {
    "path": "siirl/params/data_args.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nfrom dataclasses import asdict, dataclass, field\nfrom typing import Any, Dict, List, Literal, Optional\n\nfrom siirl.params.model_args import ProcessorArguments\n\n\n@dataclass\nclass DataArguments:\n    dataset_type: str = field(\n        default=\"llm\",\n        metadata={\"help\": \"Type of dataset, e.g., 'llm' for traditional prompt-based datasets, \"\n                          \"or 'vla' for Vision-Language-Action task manifests.\"}\n    )\n    tokenizer: Optional[str] = field(\n        default=None,\n        metadata={\"help\": \"Tokenizer configuration (null for auto-detection)\"},\n    )\n    train_files: List[str] = field(\n        default_factory=lambda: [\"~/data/rlhf/gsm8k/train.parquet\"],\n        metadata={\"help\": \"Training dataset path\"},\n    )\n    val_files: List[str] = field(\n        default_factory=lambda: [\"~/data/rlhf/gsm8k/test.parquet\"],\n        metadata={\"help\": \"Validation dataset path\"},\n    )\n    prompt_key: str = field(default=\"prompt\", metadata={\"help\": \"Dataset column name for prompts\"})\n    max_prompt_length: int = field(default=512, metadata={\"help\": \"Max token length for prompts\"})\n    max_response_length: int = field(default=512, metadata={\"help\": \"Max token length for responses\"})\n    train_batch_size: int = field(default=1024, metadata={\"help\": \"Training batch size\"})\n    gen_batch_size: Optional[int] = field(default=None, metadata={\"help\": \"Generation batch size for DAPO (typically 3x train_batch_size)\"})\n    val_batch_size: Optional[int] = field(default=None, metadata={\"help\": \"[Deprecated] Validation batch handling\"})\n    return_raw_input_ids: bool = field(default=False, metadata={\"help\": \"Return raw token IDs\"})\n    return_raw_chat: bool = field(default=False, metadata={\"help\": \"Return unprocessed chat data\"})\n    return_full_prompt: bool = field(default=False, metadata={\"help\": \"Whether to return the full prompt with chat template\"})\n    filter_overlong_prompts: bool = field(default=False, metadata={\"help\": \"For large-scale dataset, filtering overlong prompts could be timeconsuming.\"})\n    shuffle: bool = field(default=True, metadata={\"help\": \"Shuffle training data\"})\n    image_key: str = field(default=\"images\", metadata={\"help\": \"Dataset column name for images\"})\n    video_key: str = field(default=\"videos\", metadata={\"help\": \"Dataset column name for videos\"})\n    truncation: str = field(\n        default=\"error\",\n        metadata={\"help\": \"Truncate the input_ids or prompt length if they exceed max_prompt_length. Default is 'error', not allow exceed the max_prompt_length. The users should increase the max_prompt_length if throwing the error. 
You can also set ``left`` ``middle`` and ``right``\"},\n    )\n    train_on_prompt: bool = field(\n        default=False,\n        metadata={\"help\": \"Whether or not to disable the mask on the prompt.\"},\n    )\n    mask_history: bool = field(\n        default=False,\n        metadata={\"help\": \"Whether or not to mask the history and train on the last turn only.\"},\n    )\n    tool_format: Optional[str] = field(\n        default=None,\n        metadata={\"help\": \"Tool format to use for constructing function calling examples.\"},\n    )\n    tokenized_path: Optional[str] = field(\n        default=None,\n        metadata={\"help\": (\"Path to save or load the tokenized datasets. If tokenized_path not exists, it will save the tokenized datasets. If tokenized_path exists, it will load the tokenized datasets.\")},\n    )\n    dataset_cache_dir: str = field(\n        default=\"/tmp/.cache/siirl/rlhf\",\n        metadata={\"help\": \"Local cache directory for rlhf dataset.\"},\n    )\n    filter_overlong_prompt: bool = field(\n        default=False,\n        metadata={\"help\": \"Whether or not to filter prompt which length > max_prompt_length for dataset.\"},\n    )\n    serialize_dataset: bool = field(\n        default=False,\n        metadata={\"help\": \"Whether or not to store serialize dataset in state_dict.\"},\n    )\n    streaming: bool = field(\n        default=False,\n        metadata={\"help\": \"Enable dataset streaming.\"},\n    )\n    preprocessing_num_workers: Optional[int] = field(\n        default=None,\n        metadata={\"help\": \"The number of processes to use for the pre-processing.\"},\n    )\n    overwrite_cache: bool = field(\n        default=False,\n        metadata={\"help\": \"Overwrite the cached training and evaluation sets.\"},\n    )\n    preprocessing_batch_size: int = field(\n        default=1000,\n        metadata={\"help\": \"The number of examples in one group in pre-processing.\"},\n    )\n    mix_strategy: Literal[\"concat\", \"interleave_under\", \"interleave_over\"] = field(\n        default=\"concat\",\n        metadata={\"help\": \"Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling).\"},\n    )\n    interleave_probs: Optional[str] = field(\n        default=None,\n        metadata={\"help\": \"Probabilities to sample data from datasets. Use commas to separate multiple datasets.\"},\n    )\n    buffer_size: int = field(\n        default=16384,\n        metadata={\"help\": \"Size of the buffer to randomly sample examples from in dataset streaming.\"},\n    )\n    cutoff_len: int = field(\n        default=2048,\n        metadata={\"help\": \"The cutoff length of the tokenized inputs in the dataset.\"},\n    )\n    reward_fn_key: str = field(default=\"data_source\", metadata={\"help\": \"reward data source key\"})\n    multi_agent: bool = field(default=False, metadata={\"help\": \"The DAG pipeline is multi agent or not\"})\n    auto_repeat: bool = field(default=False, metadata={\"help\": \"Automatically repeats the training dataset. 
Recommended when the number of samples is smaller than the total training steps to prevent premature termination.\"})\n    num_loader_workers: int = field(default=8, metadata={\"help\": \"DataLoader worker number\"})\n    force_on_the_fly: bool = field(default=False, metadata={\"help\": \"If True, the data will be loaded on-the-fly, which is useful for large datasets that cannot fit into memory.\"})\n    processor: ProcessorArguments = field(\n        default_factory=ProcessorArguments,\n        metadata={\"help\": \"Arguments for the processor.\"},\n    )\n\n    def __post_init__(self):\n        def split_arg(arg):\n            if isinstance(arg, str):\n                return [item.strip() for item in arg.split(\",\")]\n            return arg\n\n        self.train_files = split_arg(self.train_files)\n        self.val_files = split_arg(self.val_files)\n        if self.mask_history and self.train_on_prompt:\n            raise ValueError(\"`mask_history` is incompatible with `train_on_prompt`.\")\n        if self.interleave_probs is not None:\n            if self.mix_strategy == \"concat\":\n                raise ValueError(\"`interleave_probs` is only valid for interleaved mixing.\")\n\n            self.interleave_probs = list(map(float, split_arg(self.interleave_probs)))\n            if self.train_files is not None and len(self.train_files) != len(self.interleave_probs):\n                raise ValueError(\"The length of dataset and interleave probs should be identical.\")\n\n            if self.val_files is not None and len(self.val_files) != len(self.interleave_probs):\n                raise ValueError(\"The length of eval dataset and interleave probs should be identical.\")\n\n    def to_dict(self) -> Dict[str, Any]:\n        return asdict(self)\n"
  },
  {
    "path": "siirl/params/display_dict.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nfrom loguru import logger\nfrom typing import Dict, Any, Optional, List, Tuple\n\n# --- Formatting Constants ---\nBASE_INDENT_UNIT_FOR_LOGGING = \"  \"\nDESIRED_MIN_VARIABLE_DOTS_FOR_FILLER = 1  # Minimum dots in the \"...\" part for simple values\nDEFAULT_HEADER_TEXT_LOGGING = \"details\"\nTARGET_HEADER_TOTAL_WIDTH_LOGGING = 80  # Adjusted for wider output\nTARGET_VALUE_ALIGNMENT_COLUMN_LOGGING = 80  # Target column for value alignment for simple values, lists, and sets.\n\n\ndef _render_dict_recursively_util(current_dict_to_render: Dict[str, Any], current_indent_str: str, fixed_value_align_col: int, base_indent_unit: str, lines: list):\n    \"\"\"\n    Internal recursive helper to render dictionary content.\n    Dictionaries are expanded. Lines announcing a dictionary now end with a colon.\n    Lists and sets are printed as single-line strings aligned to fixed_value_align_col.\n    Simple values are also aligned to fixed_value_align_col.\n    \"\"\"\n    try:\n        sorted_items: List[Tuple[str, Any]] = sorted([(str(k), v) for k, v in current_dict_to_render.items()])\n    except Exception as e:\n        lines.append(f\"{current_indent_str}[Could not sort keys: {e}]\")\n        sorted_items = [(str(k), v) for k, v in current_dict_to_render.items()]\n\n    if not sorted_items and current_indent_str != base_indent_unit:\n        lines.append(f\"{current_indent_str}(empty dict)\")\n        return\n\n    for key_s, value_obj in sorted_items:\n        prefix_key_only = f\"{current_indent_str}{key_s}\"\n\n        if isinstance(value_obj, dict):\n            lines.append(f\"{prefix_key_only}:\")\n            _render_dict_recursively_util(value_obj, current_indent_str + base_indent_unit, fixed_value_align_col, base_indent_unit, lines)\n        else:\n            if isinstance(value_obj, list):\n                try:\n                    val_s = json.dumps(value_obj, separators=(\",\", \":\"), ensure_ascii=False)\n                except TypeError:\n                    val_s = str(value_obj)\n            elif isinstance(value_obj, set):\n                try:\n                    val_s = json.dumps(sorted(list(value_obj)), separators=(\",\", \":\"), ensure_ascii=False)\n                except TypeError:\n                    val_s = str(value_obj)\n            else:\n                val_s = str(value_obj)\n\n            prefix_for_dots_alignment = f\"{prefix_key_only} .\"\n            suffix_part_len = len(\". \") + len(val_s)\n            dots_needed = fixed_value_align_col - len(prefix_for_dots_alignment) - suffix_part_len\n            dots = \".\" * max(DESIRED_MIN_VARIABLE_DOTS_FOR_FILLER, dots_needed)\n            lines.append(f\"{prefix_for_dots_alignment}{dots}. 
{val_s}\")\n\n\ndef log_dict_formatted(data_dict: Dict[str, Any], title: Optional[str] = \"Configuration Details\", header_text_content: str = DEFAULT_HEADER_TEXT_LOGGING, target_value_alignment_column: int = TARGET_VALUE_ALIGNMENT_COLUMN_LOGGING, log_level: str = \"info\"):\n    \"\"\"\n    Logs a dictionary with hierarchical indentation for nested dictionaries,\n    styled similarly to Megatron-LM argument printing. Uses loguru.\n    Lines announcing a nested dictionary end with a colon.\n    Lists, sets, and simple values are aligned to target_value_alignment_column.\n\n    Args:\n        data_dict (Dict[str, Any]): The dictionary to log.\n        title (Optional[str]): A title for this configuration block.\n        header_text_content (str): Text to use in the header/footer lines.\n        target_value_alignment_column (int): The column index where simple values/lists/sets should start.\n    \"\"\"\n    if not isinstance(data_dict, dict):\n        logger.error(f\"Invalid input: data_dict must be a dictionary. Received type: {type(data_dict)}\")\n        return\n\n    current_target_header_width = max(TARGET_HEADER_TOTAL_WIDTH_LOGGING, target_value_alignment_column + 10)\n    num_spaces_in_header = 2\n    padding_dashes_total = current_target_header_width - len(header_text_content) - num_spaces_in_header\n    if padding_dashes_total < 0:\n        padding_dashes_total = 0\n\n    dashes_left = padding_dashes_total // 2\n    dashes_right = padding_dashes_total - dashes_left\n    header_line = f\"{'-' * dashes_left} {header_text_content} {'-' * dashes_right}\"\n\n    lines = []\n    if title:\n        lines.append(f\"## {title} ##\")\n    lines.append(header_line)\n\n    if not data_dict:\n        lines.append(f\"{BASE_INDENT_UNIT_FOR_LOGGING}(No items in dictionary)\")\n    else:\n        _render_dict_recursively_util(data_dict, BASE_INDENT_UNIT_FOR_LOGGING, target_value_alignment_column, BASE_INDENT_UNIT_FOR_LOGGING, lines)\n\n    lines.append(\"-\" * len(header_line))\n    lines.append(\"\")\n\n    # Log all lines as a single multi-line message\n    valid_levels = [\"INFO\", \"DEBUG\", \"WARNING\", \"ERROR\", \"CRITICAL\"]\n    if log_level.upper() not in valid_levels:\n        raise ValueError(f\"Invalid log level: {log_level}. Choose from {valid_levels}.\")\n    logger.log(log_level.upper(), \"\\n\".join(lines))\n"
  },
  {
    "path": "siirl/params/embodied_args.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom dataclasses import asdict, dataclass, field\nfrom typing import Any, Dict, List, Literal, Optional\n\n\n@dataclass\nclass EnvironmentArgs:\n    \"\"\"Unified configuration for Embodied AI environments.\"\"\"\n    env_type: str = field(\n        default=\"libero\",\n        metadata={\"help\": \"Environment type: 'libero' or 'maniskill'\"}\n    )\n    env_name: str = field(\n        default=\"libero_10\",\n        metadata={\"help\": \"Name of the specific environment or task suite to load (e.g., 'libero_spatial', 'PickCube-v1')\"}\n    )\n    num_envs: int = field(\n        default=16,\n        metadata={\"help\": \"Number of parallel environments to run\"}\n    )\n    max_steps: int = field(\n        default=512,\n        metadata={\"help\": \"Maximum number of steps per episode\"}\n    )\n    num_trials_per_task: int = field(\n        default=50,\n        metadata={\"help\": \"Number of trials per task for dataset preparation\"}\n    )\n    num_steps_wait: int = field(\n        default=10,\n        metadata={\"help\": \"Number of steps to wait before polling for environment completion\"}\n    )\n    model_family: str = field(\n        default=\"openvla\",\n        metadata={\"help\": \"Model family for environment interaction (e.g., 'openvla')\"}\n    )\n\n\n@dataclass\nclass EmbodiedArguments:\n    \"\"\"Embodied AI-specific configuration for Vision-Language-Action models.\"\"\"\n\n    # Unified Environment Configuration\n    env: EnvironmentArgs = field(\n        default_factory=EnvironmentArgs,\n        metadata={\"help\": \"Unified environment configuration for Embodied AI tasks\"}\n    )\n\n    embodied_type: str = field(\n        default=\"openvla\", \n        metadata={\"help\": \"Embodied model type: 'openvla' or 'openvla-oft'\"}\n    )\n    model_path: str = field(\n        default=\"openvla/openvla-7b\", metadata={\"help\": \"Path to Embodied AI model\"}\n    )\n    video_embedding_model_path: str = field(\n        default=\"~/models/vjepa/vitg-384.pt\",\n        metadata={\"help\": \"Path to V-JEPA embedding model\"},\n    )\n\n    action_chunks_len: int = field(\n        default=8, metadata={\"help\": \"Number of action chunks per step\"}\n    )\n    action_token_len: int = field(\n        default=7, metadata={\"help\": \"Number of action tokens\"}\n    )\n\n    # Image processing\n    embedding_img_size: int = field(\n        default=384, metadata={\"help\": \"Image size for video embedding\"}\n    )\n    embedding_enable_fp16: bool = field(\n        default=True, metadata={\"help\": \"Enable FP16 for video embedding\"}\n    )\n    # Embedding model configuration\n    embedding_model_class: Optional[str] = field(\n        default=None, metadata={\"help\": \"Custom embedding model class path\"}\n    )\n    embedding_model_offload: bool = field(\n        default=False,\n        metadata={\"help\": \"Offload embedding model to CPU 
when not in use\"},\n    )\n    num_images_in_input: int = field(\n        default=1, metadata={\"help\": \"Number of camera views (1=main, 2=main+wrist)\"}\n    )\n    center_crop: bool = field(\n        default=True, metadata={\"help\": \"Apply center crop to images\"}\n    )\n    # Action normalization stats (will be populated from config)\n    unnorm_key: str = field(\n        default=\"libero_10\",\n        metadata={\"help\": \"Key for action normalization stats used in un-normalization\"},\n    )\n    # Generation parameters\n    temperature: float = field(\n        default=1.6, metadata={\"help\": \"Sampling temperature for action generation\"}\n    )\n\n    n_gpus_per_node: int = field(\n        default=8,\n        metadata={\"help\": \"Number of GPUs per node\"}\n    )\n    \n    def to_dict(self) -> Dict[str, Any]:\n        return asdict(self)\n\n@dataclass\nclass EmbodiedSamplingConfig:\n    \"\"\"Configuration for embodied AI dynamic sampling and filtering (similar to DAPO filter_groups).\"\"\"\n    \n    filter_accuracy: bool = field(\n        default=False,\n        metadata={\"help\": \"Enable accuracy-based filtering for embodied tasks\"}\n    )\n    accuracy_lower_bound: float = field(\n        default=0.0,\n        metadata={\"help\": \"Minimum success rate threshold for keeping prompts\"}\n    )\n    accuracy_upper_bound: float = field(\n        default=1.0,\n        metadata={\"help\": \"Maximum success rate threshold for keeping prompts\"}\n    )\n    filter_truncated: bool = field(\n        default=False,\n        metadata={\"help\": \"Filter out truncated episodes (uses env.max_steps for truncation detection)\"}\n    )\n    oversample_factor: int = field(\n        default=1,\n        metadata={\"help\": \"Oversample factor for data filtering\"}\n    )\n    \n    def to_dict(self) -> Dict[str, Any]:\n        return asdict(self)\n\n\n__all__ = [\n    \"EnvironmentArgs\",\n    \"EmbodiedArguments\",\n    \"EmbodiedSamplingConfig\"\n]\n"
  },
  {
    "path": "siirl/params/model_args.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n# Copyright 2025, Infrawaves. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom dataclasses import asdict, dataclass, field\nfrom typing import Any, Dict, List, Literal, Optional\nfrom .embodied_args import EmbodiedArguments, EmbodiedSamplingConfig\n\n\n@dataclass\nclass MixedPrecisionArguments:\n    param_dtype: Literal[\"float16\", \"bfloat16\", \"float32\"] = field(\n        default=\"bfloat16\",\n        metadata={\"help\": \"Param precision to use for fsdp MixedPrecision model\"},\n    )\n    reduce_dtype: Literal[\"float16\", \"bfloat16\", \"float32\"] = field(\n        default=\"float32\",\n        metadata={\"help\": \"Reduce precision to use for fsdp MixedPrecision model\"},\n    )\n    buffer_dtype: Literal[\"float16\", \"bfloat16\", \"float32\"] = field(\n        default=\"float32\",\n        metadata={\"help\": \"Buffer precision to use for fsdp MixedPrecision model\"},\n    )\n    keep_low_precision_grads: bool = field(default=False, metadata={\"help\": \"Whether or not to use low precision grad\"})\n    cast_forward_inputs: bool = field(default=False, metadata={\"help\": \"Whether or not to cast forward inputs\"})\n    cast_root_forward_inputs: bool = field(\n        default=True, metadata={\"help\": \"Whether or not to cast root forward inputs\"}\n    )\n\n\n@dataclass\nclass FSDPArguments:\n    wrap_policy: Dict[str, Any] = field(\n        default_factory=lambda: {\"min_num_params\": 0},\n        metadata={\"help\": \"Wrapping policy configuration\"},\n    )\n    param_offload: bool = field(default=False, metadata={\"help\": \"Parameter offloading\"})\n    optimizer_offload: bool = field(default=False, metadata={\"help\": \"Optimizer state offloading\"})\n    fsdp_size: int = field(default=-1, metadata={\"help\": \"FSDP group size\"})\n    model_dtype: Literal[\"float16\", \"bfloat16\", \"float32\"] = field(\n        default=\"float32\",\n        metadata={\"help\": \"PrecisionType to use for model\"},\n    )\n    mixed_precision: MixedPrecisionArguments = field(\n        default_factory=MixedPrecisionArguments,\n        metadata={\"help\": \"fsdp mixed precision settings\"},\n    )\n\n    def to_dict(self) -> Dict[str, Any]:\n        return asdict(self)\n\n\n@dataclass\nclass MegatronArguments:\n    tensor_model_parallel_size: int = field(default=1, metadata={\"help\": \"Tensor parallelism size\"})\n    pipeline_model_parallel_size: int = field(default=1, metadata={\"help\": \"Pipeline parallelism size\"})\n    context_parallel_size: int = field(default=1, metadata={\"help\": \"Context parallelism size\"})\n    expert_model_parallel_size: int = field(default=1, metadata={\"help\": \"Expert model parallelism size\"})\n    expert_tensor_parallel_size: int = field(default=1, metadata={\"help\": \"Expert tensor parallelism size\"})\n    virtual_pipeline_model_parallel_size: Optional[int] = field(\n        default=None, metadata={\"help\": \"Virtual pipeline 
model parallel size\"}\n    )\n    sequence_parallel: bool = field(default=False, metadata={\"help\": \"Whether the sequence parallel is enabled.\"})\n    use_distributed_optimizer: bool = field(\n        default=True,\n        metadata={\"help\": \"Whether the distributed optimizer is enabled.\"},\n    )\n    param_dtype: str = field(default=\"bfloat16\", metadata={\"help\": \"parameter data dtype\"})\n    seed: int = field(default=1, metadata={\"help\": \"The random seed\"})\n    param_offload: bool = field(default=True, metadata={\"help\": \"Offload parameters to CPU\"})\n    grad_offload: bool = field(default=False, metadata={\"help\": \"Offload gradients to CPU\"})\n    optimizer_offload: bool = field(default=False, metadata={\"help\": \"Offload optimizer states to CPU\"})\n    extra: Dict[str, Any] = field(default_factory=dict, metadata={\"help\": \"Extra settings\"})\n    override_transformer_config: Dict[str, Any] = field(\n        default_factory=dict, metadata={\"help\": \"Override transformer config\"}\n    )\n    use_dist_checkpointing: bool = field(default=False, metadata={\"help\": \"Whether to use distributed checkpointing\"})\n    dist_checkpointing_path: str = field(default=\"\", metadata={\"help\": \"Path to save distributed checkpointing\"})\n    override_ddp_config: Dict[str, Any] = field(default_factory=dict, metadata={\"help\": \"Override ddp config\"})\n    use_mbridge: bool = field(default=False, metadata={\"help\": \"Whether to use mbridge\"})\n\n    def to_dict(self) -> Dict[str, Any]:\n        return asdict(self)\n\n\n@dataclass\nclass OptimizerArguments:\n    lr: float = field(default=1e-6, metadata={\"help\": \"Learning rate\"})\n    lr_warmup_steps_ratio: float = field(default=0.0, metadata={\"help\": \"Warmup steps ratio\"})\n    min_lr: float = field(default=0.0, metadata={\"help\": \"Min learning rate\"})\n    min_lr_ratio: Optional[float] = field(default=0.0, metadata={\"help\": \"Min learning rate ratio\"})\n    warmup_style: str = field(default=\"constant\", metadata={\"help\": \"Warmup strategy\"})\n    lr_warmup_init: float = field(default=0.0, metadata={\"help\": \"Learning rate warmup init\"})\n    lr_decay_steps: Optional[int] = field(default=None, metadata={\"help\": \"Learning rate decay steps\"})\n    lr_decay_style: str = field(default=\"linear\", metadata={\"help\": \"Learning rate decay style\"})\n    weight_decay_incr_style: str = field(default=\"constant\", metadata={\"help\": \"Weight decay increase style\"})\n    lr_wsd_decay_style: str = field(default=\"exponential\", metadata={\"help\": \"Learning rate warmup decay style\"})\n    lr_wsd_decay_steps: Optional[int] = field(default=None, metadata={\"help\": \"Learning rate warmup decay steps\"})\n    use_checkpoint_opt_param_scheduler: bool = field(\n        default=False, metadata={\"help\": \"Whether to use checkpoint opt param scheduler\"}\n    )\n    total_training_steps: int = field(default=0, metadata={\"help\": \"Total training steps\"})\n    betas: tuple[float, float] = field(default=(0.9, 0.999), metadata={\"help\": \"Beta params Of Optimizer\"})\n    weight_decay: float = field(default=1e-2, metadata={\"help\": \"Weight decay params of Optimizer\"})\n    lr_warmup_steps: int = field(\n        default=-1,\n        metadata={\"help\": \"Prioritized. 
Negative values mean delegating to lr_warmup_steps_ratio.\"},\n    )\n    clip_grad: float = field(default=1.0, metadata={\"help\": \"gradient clip\"})\n    num_cycles: float = field(default=0.5, metadata={\"help\": \"num cycles\"})\n    override_optimizer_config: Optional[dict] = field(default=None, metadata={\"help\": \"Override optimizer config\"})\n\n    def to_dict(self) -> Dict[str, Any]:\n        return asdict(self)\n\n\n@dataclass\nclass ProcessorArguments:\n    r\"\"\"\n    Arguments pertaining to the image processor.\n    \"\"\"\n\n    image_max_pixels: int = field(\n        default=768 * 768,\n        metadata={\"help\": \"The maximum number of pixels of image inputs.\"},\n    )\n    image_min_pixels: int = field(\n        default=32 * 32,\n        metadata={\"help\": \"The minimum number of pixels of image inputs.\"},\n    )\n    video_max_pixels: int = field(\n        default=256 * 256,\n        metadata={\"help\": \"The maximum number of pixels of video inputs.\"},\n    )\n    video_min_pixels: int = field(\n        default=16 * 16,\n        metadata={\"help\": \"The minimum number of pixels of video inputs.\"},\n    )\n    video_fps: float = field(\n        default=2.0,\n        metadata={\"help\": \"The frames to sample per second for video inputs.\"},\n    )\n    video_maxlen: int = field(\n        default=128,\n        metadata={\"help\": \"The maximum number of sampled frames for video inputs.\"},\n    )\n\n\n@dataclass\nclass ModelArguments(ProcessorArguments):\n    path: str = field(\n        default=\"~/models/deepseek-llm-7b-chat\",\n        metadata={\"help\": \"Model path or identifier\"},\n    )\n    external_lib: Optional[str] = field(default=None, metadata={\"help\": \"External model library\"})\n    override_config: Dict[str, Any] = field(default_factory=dict, metadata={\"help\": \"Model config overrides\"})\n    enable_gradient_checkpointing: bool = field(default=True, metadata={\"help\": \"Gradient checkpointing\"})\n    gradient_checkpointing_kwargs: Dict[str, Any] = field(\n        default_factory=dict, metadata={\"help\": \"Gradient checkpointing kwargs\"}\n    )\n    use_remove_padding: bool = field(default=False, metadata={\"help\": \"Padding removal optimization\"})\n    use_fused_kernels: bool = field(default=False, metadata={\"help\": \"Kernels fuse optimization\"})\n    cache_dir: Optional[str] = field(\n        default=None,\n        metadata={\"help\": \"Download from hugging face, modelscope, openmind local cache dir\"},\n    )\n    model_revision: str = field(\n        default=\"main\",\n        metadata={\"help\": \"The specific model version to use (can be a branch name, tag name or commit id).\"},\n    )\n    trust_remote_code: bool = field(\n        default=False,\n        metadata={\"help\": \"Whether to trust the execution of code from datasets/models defined on the Hub or not.\"},\n    )\n    hf_hub_token: Optional[str] = field(\n        default=None,\n        metadata={\"help\": \"Auth token to log in with Hugging Face Hub.\"},\n    )\n    use_fast_tokenizer: bool = field(\n        default=True,\n        metadata={\"help\": \"Whether or not to use one of the fast tokenizer (backed by the tokenizers library).\"},\n    )\n    split_special_tokens: bool = field(\n        default=False,\n        metadata={\"help\": \"Whether or not the special tokens should be split during the tokenization process.\"},\n    )\n    model_max_length: Optional[int] = field(\n        default=None,\n        metadata={\"help\": \"The maximum input length for model, 
derived from `cutoff_len`. Do not specify it.\"},\n    )\n    new_special_tokens: Optional[str] = field(\n        default=None,\n        metadata={\"help\": \"Special tokens to be added into the tokenizer. Use commas to separate multiple tokens.\"},\n    )\n    resize_vocab: bool = field(\n        default=False,\n        metadata={\"help\": \"Whether or not to resize the tokenizer vocab and the embedding layers.\"},\n    )\n    use_liger: bool = field(\n        default=False,\n        metadata={\"help\": \"Whether or not to apply Liger kernel to the model\"},\n    )\n    fsdp_config: FSDPArguments = field(default_factory=FSDPArguments, metadata={\"help\": \"FSDP settings\"})\n    megatron: MegatronArguments = field(default_factory=MegatronArguments, metadata={\"help\": \"Megatron settings\"})\n    input_tokenizer: Optional[str] = field(default=None, metadata={\"help\": \"input tokenizer path\"})\n    rm_tokenizer: Optional[str] = field(default=None, metadata={\"help\": \"reward model tokenizer path\"})\n    lora_rank: int = field(default=0, metadata={\"help\": \"set to positive value to enable LoRA (e.g., 32)\"})\n    lora_alpha: float = field(default=16, metadata={\"help\": \"LoRA scaling factor\"})\n    target_modules: str = field(\n        default=\"all-linear\",\n        metadata={\"help\": \"all-linear or [q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj]\"},\n    )\n    use_shm: bool = field(default=False)\n    enable_activation_offload: bool = field(default=False, metadata={\"help\": \"enable activation offload\"})\n    model_type: str = field(default=\"llm\", metadata={\"help\": \"model type\"})\n\n    def __post_init__(self):\n        if self.path is None:\n            raise ValueError(\"Please provide `path`.\")\n\n        if self.split_special_tokens and self.use_fast_tokenizer:\n            raise ValueError(\"`split_special_tokens` is only supported for slow tokenizers.\")\n\n        if self.new_special_tokens is not None:  # support multiple special tokens\n            self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(\",\")]\n\n    def to_dict(self) -> Dict[str, Any]:\n        return asdict(self)\n\n\n@dataclass\nclass CheckpointArguments:\n    contents: List[str] = field(\n        default_factory=lambda: [\"model\", \"optimizer\", \"extra\"],  \n        metadata={\"help\": \"The contents to save and load in the checkpoint.\"}\n    )\n    save_contents: List[str] = field(\n        default_factory=lambda: [\"model\", \"optimizer\", \"extra\"],  \n        metadata={\"help\": \"The contents to save in the checkpoint.\"}\n    )\n    load_contents: List[str] = field(\n        default_factory=lambda: [\"model\", \"optimizer\", \"extra\"],  \n        metadata={\"help\": \"The contents to load in the checkpoint.\"}\n    )\n    async_save: bool = field(default=False, metadata={\"help\": \"Async checkpoint save mode\"})\n\n\n@dataclass\nclass PolicyLossArguments:\n    loss_mode: str = field(\n        default=\"vanilla\", metadata={\"help\": \"Loss function mode. 
Options: 'vanilla', 'clip-cov', 'kl-cov', 'gpg'.\"}\n    )\n    clip_cov_ratio: float = field(default=0.0002, metadata={\"help\": \"Ratio of tokens to be clipped for clip-cov loss.\"})\n    clip_cov_lb: float = field(default=1.0, metadata={\"help\": \"Lower bound for clip-cov loss.\"})\n    clip_cov_ub: float = field(default=5.0, metadata={\"help\": \"Upper bound for clip-cov loss.\"})\n    kl_cov_ratio: float = field(\n        default=0.0002, metadata={\"help\": \"Ratio of tokens to be applied KL penalty for kl-cov loss.\"}\n    )\n    ppo_kl_coef: float = field(default=0.1, metadata={\"help\": \"KL divergence penalty coefficient.\"})\n\n\n@dataclass\nclass ActorArguments:\n    strategy: str = field(default=\"fsdp\", metadata={\"help\": \"Parallel strategy\"})\n    ppo_mini_batch_size: int = field(default=256, metadata={\"help\": \"PPO mini-batch size\"})\n    ppo_micro_batch_size: Optional[int] = field(default=None, metadata={\"help\": \"[Deprecated] Micro-batch size\"})\n    ppo_micro_batch_size_per_gpu: Optional[int] = field(default=1, metadata={\"help\": \"Per-GPU micro-batch size\"})\n    use_dynamic_bsz: bool = field(default=False, metadata={\"help\": \"Dynamic batch sizing\"})\n    ppo_max_token_len_per_gpu: int = field(default=16384, metadata={\"help\": \"Max tokens per GPU\"})\n    grad_clip: float = field(default=1.0, metadata={\"help\": \"Gradient clipping\"})\n    clip_ratio: float = field(default=0.2, metadata={\"help\": \"Clipping ratio\"})\n    clip_ratio_low: float = field(default=0.2, metadata={\"help\": \"Min value for clip ratio\"})\n    clip_ratio_high: float = field(default=0.2, metadata={\"help\": \"Max value for clip ratio\"})\n    clip_ratio_c: float = field(default=3.0, metadata={\"help\": \"lower bound of the value for Dual-clip PPO from https://arxiv.org/pdf/1912.09729\"})\n    entropy_coeff: float = field(default=0, metadata={\"help\": \"Entropy coefficient\"})\n    use_kl_loss: bool = field(default=False, metadata={\"help\": \"Enable KL loss\"})\n    kl_loss_coef: float = field(default=0.001, metadata={\"help\": \"KL loss coefficient\"})\n    kl_loss_type: str = field(default=\"low_var_kl\", metadata={\"help\": \"KL loss type\"})\n    ppo_epochs: int = field(default=1, metadata={\"help\": \"PPO epochs\"})\n    shuffle: bool = field(default=False, metadata={\"help\": \"Data shuffling\"})\n    ulysses_sequence_parallel_size: int = field(default=1, metadata={\"help\": \"Sequence parallel size\"})\n    policy_loss: PolicyLossArguments = field(\n        default_factory=PolicyLossArguments, metadata={\"help\": \"Policy loss settings\"}\n    )\n    tis_imp_ratio_cap: float = field(default=-1, metadata={\"help\": \"Truncated importance sampling ratio cap\"})\n    optim: OptimizerArguments = field(default_factory=OptimizerArguments, metadata={\"help\": \"Optimizer settings\"})\n    fsdp_config: FSDPArguments = field(default_factory=FSDPArguments, metadata={\"help\": \"FSDP settings\"})\n    megatron: MegatronArguments = field(default_factory=MegatronArguments, metadata={\"help\": \"Megatron settings\"})\n    use_remove_padding: bool = field(default=False, metadata={\"help\": \"Padding removal optimization\"})\n    use_fused_kernels: bool = field(default=False, metadata={\"help\": \"Kernels fuse optimization\"})\n    use_torch_compile: bool = field(default=True, metadata={\"help\": \"Whether or not use torch compile\"})\n    checkpoint: CheckpointArguments = field(\n        default_factory=CheckpointArguments, metadata={\"help\": \"Checkpoint configuration\"}\n    )\n   
 param_offload: bool = field(default=False, metadata={\"help\": \"Enable param offload or not\"})\n    grad_offload: bool = field(default=False, metadata={\"help\": \"Enable grad offload or not\"})\n    optimizer_offload: bool = field(default=False, metadata={\"help\": \"Enable optimizer offload or not\"})\n    load_weight: bool = field(default=True)\n    loss_agg_mode: str = field(default=\"token-mean\", metadata={\"help\": \"seq-mean-token-sum, seq-mean-token-mean\"})\n    recompute_old_log_prob: bool = field(default=True, metadata={\"help\": \"recompute old log prob\"})\n    use_cpgd_loss: bool = field(default=False, metadata={\"help\": \"use cpgd loss\"})\n    policy_drift_coeff: float = field(default=0.0, metadata={\"help\": \"policy drift coeff for CPGD\"})\n    data_loader_seed: Optional[int] = field(default=None, metadata={\"help\": \"Data loader seed\"})\n    profile: dict[str, Any] = field(default_factory=dict, metadata={\"help\": \"Actor Profile settings\"})\n    entropy_checkpointing: bool = field(default=False, metadata={\"help\": \"Enable entropy checkpointing\"})\n    entropy_from_logits_with_chunking: bool = field(\n        default=False, metadata={\"help\": \"Enable entropy from logits with chunking\"}\n    )\n    # Embodied AI parameters (inherited from EmbodiedArguments at runtime)\n    embodied_type: Optional[str] = field(\n        default=None, \n        metadata={\"help\": \"Embodied model type: 'openvla' or 'openvla-oft', inherited from embodied.embodied_type\"}\n    )\n    action_token_len: Optional[int] = field(\n        default=None,\n        metadata={\"help\": \"Number of action tokens, inherited from embodied.action_token_len at runtime\"}\n    )\n    action_chunks_len: Optional[int] = field(\n        default=None,\n        metadata={\"help\": \"Number of action chunks, inherited from embodied.action_chunks_len at runtime\"}\n    )\n    # Actor-specific training parameters\n    traj_mini_batch_size: int = field(\n        default=16,\n        metadata={\"help\": \"Mini-batch size for trajectory splitting during training (must divide traj_len evenly)\"}\n    )\n\n    def to_dict(self) -> Dict[str, Any]:\n        return asdict(self)\n\n\n@dataclass\nclass EvalSamplingArguments:\n    top_k: int = field(default=-1, metadata={\"help\": \"0 for hf rollout, -1 for vllm rollout\"})\n    top_p: float = field(default=1.0)\n    temperature: int = field(default=0)\n    n: int = field(default=1)\n    do_sample: bool = field(default=False)\n\n\n@dataclass\nclass LayerNameMapArguments:\n    qkv_layer_name: str = field(default=\"qkv\", metadata={\"help\": \"QKV layer name map\"})\n    gate_proj_layer_name: str = field(default=\"linear_fc1.weight\", metadata={\"help\": \"Gate projection layer name map\"})\n\n    def to_dict(self) -> Dict[str, Any]:\n        return asdict(self)\n\n\n@dataclass\nclass MultiTurnArguments:\n    use_all_traj: bool = field(\n        default=False,\n        metadata={\n            \"help\": \"Set to True for multi-agent generation when trajectories from each round of \"\n            \"multi-turn are needed for training.\"\n        },\n    )\n    enable: bool = field(default=False, metadata={\"help\": \"should set rollout.name to sglang_async if True\"})\n    max_assistant_turns: Optional[int] = field(\n        default=None, metadata={\"help\": \"null for no limit (default max_length // 3)\"}\n    )\n    tool_config_path: Optional[str] = field(default=None, metadata={\"help\": \"null for no tool\"})\n    format: str = field(\n        default=\"hermes\", 
metadata={\"help\": \"Format of the multi-turn interaction. Options: hermes, llama3_json, ...\"}\n    )\n    max_user_turns: Optional[int] = field(\n        default=None, metadata={\"help\": \"null for no limit (default max_length // 3)\"}\n    )\n    max_parallel_calls: int = field(default=1, metadata={\"help\": \"max parallel call for tools in single turn\"})\n    max_tool_response_length: int = field(default=256, metadata={\"help\": \"max length of tool response\"})\n    tool_response_truncate_side: str = field(\n        default=\"middle\", metadata={\"help\": \"truncate side of tool response: left, middle, right\"}\n    )\n    interaction_config_path: Optional[str] = field(default=None, metadata={\"help\": \"null for no interaction\"})\n    completion_callback: Optional[str] = field(default=None, metadata={\"help\": \"null for default callback\"})\n    use_inference_chat_template: bool = field(\n        default=False,\n        metadata={\n            \"help\": \"- When set to True, the model's default chat template is used for multi-turn rollout, \"\n            \"which typically matches production behavior. \\n \\- When set to False, the token ids recorded for \"\n            \"training are used instead; unlike the default chat template, these always include the model's \"\n            \"full output, \\n \\which may contain additional content such as reasoning content. This maintains \"\n            \"the consistency between training and rollout, but it will lead to longer prompts.\"\n        },\n    )\n    tokenization_sanity_check_mode: str = field(\n        default=\"strict\",\n        metadata={\n            \"help\": \"- disable: disable tokenization sanity check \\n \\\n    - strict: enable strict tokenization sanity check (default) \\n \\\n    - ignore_strippable: ignore strippable tokens when checking tokenization sanity\"\n        },\n    )\n\n\n@dataclass\nclass CustomAsyncServer:\n    path: Optional[str] = None\n    # Path to the custom async server implementation\n    name: Optional[str] = None\n    # Class name of the custom async server class (e.g. 
AsyncvLLMServer)\n\n\n@dataclass\nclass AgentArguments:\n    agent_name: str = field(default=\"single_turn_agent\", metadata={\"help\": \"choose which agent tool\"})\n    num_workers: int = field(default=1, metadata={\"help\": \"custom async server configs\"})\n    # custom async server configs\n    custom_async_server: CustomAsyncServer = field(default=None, metadata={\"help\": \"custom async server configs\"})\n    # Path to the custom async server implementation\n\n    train_cycle: int = field(default=None, metadata={\"help\": \"Train cycle\"})\n    process_path: str = field(default=None, metadata={\"help\": \"Path to the pre-process function\"})\n    pre_process_kwargs: dict = field(default_factory=dict, metadata={\"help\": \"Pre-process function arguments\"})\n    post_process_kwargs: dict = field(default_factory=dict, metadata={\"help\": \"Post-process function arguments\"})\n    env_path: list = field(default_factory=list, metadata={\"help\": \"Env path list\"})\n    obs_with_env: bool = field(default=False, metadata={\"help\": \"Rollout with obs from Env\"})\n    rewards_with_env: bool = field(default=False, metadata={\"help\": \"Use rewards from Env\"})\n    share_instance: int = field(default=None, metadata={\"help\": \"Use the same instance with the target agent group\"})\n\n\n@dataclass\nclass EngineArguments:\n    vllm: Dict[str, Any] = field(default_factory=lambda: {})\n    sglang: Dict[str, Any] = field(default_factory=lambda: {})\n\n\n@dataclass\nclass RolloutArguments:\n    name: str = field(default=\"vllm\", metadata={\"help\": \"Rollout engine\"})\n    temperature: float = field(default=1.0, metadata={\"help\": \"Sampling temperature\"})\n    top_k: int = field(default=-1, metadata={\"help\": \"Top-k sampling\"})\n    top_p: float = field(default=1.0, metadata={\"help\": \"Top-p sampling\"})\n    use_fire_sampling: bool = field(default=False, metadata={\"help\": \"Fire sampling optimization\"})\n    prompt_length: int = field(default=None, metadata={\"help\": \"Prompt length\"})\n    response_length: int = field(default=None, metadata={\"help\": \"Response length\"})\n    dtype: str = field(default=\"bfloat16\", metadata={\"help\": \"Compute dtype\"})\n    gpu_memory_utilization: float = field(default=0.5, metadata={\"help\": \"GPU memory usage\"})\n    ignore_eos: bool = field(default=False, metadata={\"help\": \"Ignore EOS tokens\"})\n    enforce_eager: bool = field(default=True, metadata={\"help\": \"Eager execution\"})\n    free_cache_engine: bool = field(default=True, metadata={\"help\": \"Free GPU cache\"})\n    load_format: str = field(default=\"dummy_dtensor\", metadata={\"help\": \"Weight loading format\"})\n    tensor_model_parallel_size: int = field(default=1, metadata={\"help\": \"Tensor parallelism\"})\n    max_num_batched_tokens: int = field(default=8192, metadata={\"help\": \"Max batched tokens\"})\n    max_model_len: Optional[int] = field(default=None, metadata={\"help\": \"Max model length\"})\n    max_num_seqs: int = field(default=1024, metadata={\"help\": \"Max concurrent sequences\"})\n    limit_images: Optional[int] = field(default=None, metadata={\"help\": \"support for multi-image data\"})\n    do_sample: bool = field(default=True, metadata={\"help\": \"Enable sampling\"})\n    n: int = field(default=1, metadata={\"help\": \"Number of responses\"})\n    log_prob_micro_batch_size: Optional[int] = field(\n        default=None, metadata={\"help\": \"[Deprecated] Log prob batch size\"}\n    )\n    log_prob_micro_batch_size_per_gpu: Optional[int] = 
field(\n        default=1, metadata={\"help\": \"Per-GPU log prob batch size\"}\n    )\n    log_prob_max_token_len_per_gpu: int = field(default=16384, metadata={\"help\": \"Max tokens per GPU\"})\n    log_prob_use_dynamic_bsz: bool = field(default=False, metadata={\"help\": \"Dynamic log prob batch size\"})\n    disable_log_stats: bool = field(default=True, metadata={\"help\": \"Whether or not disable log stats\"})\n    enable_chunked_prefill: bool = field(default=True, metadata={\"help\": \"Whether or not enable chunked prefill\"})\n    trust_remote_code: bool = field(default=True, metadata={\"help\": \"trust the code or not.\"})\n    val_kwargs: EvalSamplingArguments = field(default_factory=EvalSamplingArguments)\n    layer_name_map: LayerNameMapArguments = field(default_factory=LayerNameMapArguments)\n    seed: int = field(default=0, metadata={\"help\": \"The random seed\"})\n    mode: str = field(default=\"sync\", metadata={\"help\": \"sync: LLM, async: AsyncLLM\"})\n    multi_turn: MultiTurnArguments = field(default_factory=MultiTurnArguments)\n    micro_batch_size: Optional[int] = field(default=None, metadata={\"help\": \"Inference micro-batch size\"})\n    engine_kwargs: EngineArguments = field(default_factory=EngineArguments)\n    calculate_log_probs: bool = field(\n        default=False, metadata={\"help\": \"support logging rollout prob for debugging purpose\"}\n    )\n    agent: AgentArguments = field(default_factory=AgentArguments)\n    multi_stage_wake_up: bool = field(\n        default=False,\n        metadata={\n            \"help\": \"# Whether to wake up inference engine in multi-stage. (Wake up model weights first, \"\n            \"then resume kv cache)\"\n        },\n    )\n\n    def to_dict(self) -> Dict[str, Any]:\n        return asdict(self)\n\n\n@dataclass\nclass RefArguments:\n    strategy: str = field(default=\"fsdp\", metadata={\"help\": \"Parallel strategy\"})\n    fsdp_config: FSDPArguments = field(default_factory=FSDPArguments, metadata={\"help\": \"Reference FSDP settings\"})\n    megatron: MegatronArguments = field(default_factory=MegatronArguments, metadata={\"help\": \"Megatron settings\"})\n    log_prob_micro_batch_size: Optional[int] = field(\n        default=None, metadata={\"help\": \"[Deprecated] Log prob batch size\"}\n    )\n    log_prob_micro_batch_size_per_gpu: Optional[int] = field(\n        default=1, metadata={\"help\": \"Per-GPU log prob batch size\"}\n    )\n    log_prob_use_dynamic_bsz: bool = field(default=False, metadata={\"help\": \"Dynamic log prob batch size\"})\n    log_prob_max_token_len_per_gpu: int = field(default=16384, metadata={\"help\": \"Max tokens per GPU\"})\n    ulysses_sequence_parallel_size: int = field(default=1, metadata={\"help\": \"Sequence parallel size\"})\n    use_remove_padding: bool = field(default=False, metadata={\"help\": \"Padding removal optimization\"})\n    use_fused_kernels: bool = field(default=False, metadata={\"help\": \"Kernels fuse optimization\"})\n    use_torch_compile: bool = field(default=True, metadata={\"help\": \"Whether or not use torch compile\"})\n    ppo_micro_batch_size: Optional[int] = field(default=None, metadata={\"help\": \"[Deprecated] Micro-batch size\"})\n    ppo_micro_batch_size_per_gpu: Optional[int] = field(default=None, metadata={\"help\": \"Per-GPU micro-batch size\"})\n    param_offload: bool = field(default=False, metadata={\"help\": \"Enable param offload or not\"})\n    grad_offload: bool = field(default=False, metadata={\"help\": \"Enable grad offload or not\"})\n    
optimizer_offload: bool = field(default=False, metadata={\"help\": \"Enable optimizer offload or not\"})\n    load_weight: bool = field(default=True)\n    profile: dict[str, Any] = field(default_factory=dict, metadata={\"help\": \"Reference Profile settings\"})\n    shuffle: bool = field(default=False, metadata={\"help\": \"Data shuffling\"})\n    data_loader_seed: Optional[int] = field(default=None, metadata={\"help\": \"Data loader seed\"})\n    recompute_old_log_prob: bool = field(default=True, metadata={\"help\": \"recompute old log prob\"})\n    entropy_checkpointing: bool = field(default=False, metadata={\"help\": \"Enable entropy checkpointing\"})\n    entropy_from_logits_with_chunking: bool = field(\n        default=False, metadata={\"help\": \"Enable entropy from logits with chunking\"}\n    )\n    embodied_type: Optional[str] = field(\n        default=None, \n        metadata={\"help\": \"Embodied model type: 'openvla' or 'openvla-oft', None for non-embodied models\"}\n    )\n\n    def to_dict(self) -> Dict[str, Any]:\n        return asdict(self)\n\n\n@dataclass\nclass ActorRolloutRefArguments:\n    hybrid_engine: bool = field(default=True, metadata={\"help\": \"Hybrid engine mode\"})\n    model: ModelArguments = field(default_factory=ModelArguments, metadata={\"help\": \"Base model settings\"})\n    actor: ActorArguments = field(default_factory=ActorArguments, metadata={\"help\": \"Actor configuration\"})\n    ref: RefArguments = field(default_factory=RefArguments, metadata={\"help\": \"Reference model settings\"})\n    rollout: RolloutArguments = field(default_factory=RolloutArguments, metadata={\"help\": \"Rollout parameters\"})\n    embodied: EmbodiedArguments = field(default_factory=EmbodiedArguments, metadata={\"help\": \"Embodied AI settings\"})\n\n    def to_dict(self) -> Dict[str, Any]:\n        return asdict(self)\n\n\n@dataclass\nclass CriticArguments:\n    strategy: str = field(default=\"fsdp\", metadata={\"help\": \"Parallel strategy\"})\n    optim: OptimizerArguments = field(\n        default_factory=lambda: OptimizerArguments(lr=1e-5),\n        metadata={\"help\": \"Optimizer settings\"},\n    )\n    model: ModelArguments = field(\n        default_factory=lambda: ModelArguments(\n            path=\"~/models/deepseek-llm-7b-chat\", enable_gradient_checkpointing=True\n        ),\n        metadata={\"help\": \"Critic model\"},\n    )\n    fsdp_config: FSDPArguments = field(default_factory=FSDPArguments, metadata={\"help\": \"FSDP settings\"})\n    megatron: MegatronArguments = field(default_factory=MegatronArguments, metadata={\"help\": \"Megatron settings\"})\n    ppo_mini_batch_size: int = field(default=256, metadata={\"help\": \"PPO mini-batch size\"})\n    ppo_micro_batch_size: Optional[int] = field(default=None, metadata={\"help\": \"[Deprecated] Micro-batch size\"})\n    ppo_micro_batch_size_per_gpu: Optional[int] = field(default=None, metadata={\"help\": \"Per-GPU micro-batch size\"})\n    use_dynamic_bsz: bool = field(default=False, metadata={\"help\": \"Dynamic batch size\"})\n    ppo_epochs: int = field(default=1, metadata={\"help\": \"PPO epochs\"})\n    shuffle: bool = field(default=False, metadata={\"help\": \"Data shuffling\"})\n    grad_clip: float = field(default=1.0, metadata={\"help\": \"Gradient clipping\"})\n    cliprange_value: float = field(default=0.5, metadata={\"help\": \"Value clipping range\"})\n    ulysses_sequence_parallel_size: int = field(default=1, metadata={\"help\": \"Sequence parallel size\"})\n    forward_max_token_len_per_gpu: int 
= field(default=32768, metadata={\"help\": \"Forward max token length per GPU\"})\n    load_weight: bool = field(default=True)\n    rollout_n: int = field(default=1, metadata={\"help\": \"rollout n\"})\n    checkpoint: CheckpointArguments = field(default_factory=CheckpointArguments, metadata={\"help\": \"Checkpoint configuration\"})\n    ppo_max_token_len_per_gpu: int = field(default=32768, metadata={\"help\": \"Max tokens per GPU\"})\n    loss_agg_mode: str = field(default=\"token-mean\", metadata={\"help\": \"token-mean, seq-mean-token-sum, seq-mean-token-mean\"})\n    profile: dict[str, Any] = field(default_factory=dict, metadata={\"help\": \"Critic Profile settings\"})\n    data_loader_seed: Optional[int] = field(default=None, metadata={\"help\": \"Data loader seed\"})\n\n    def to_dict(self) -> Dict[str, Any]:\n        return asdict(self)\n\n\n@dataclass\nclass OverlongBufferArguments:\n    \"\"\"DAPO-specific overlong buffer configuration for handling sequences longer than max length.\"\"\"\n\n    enable: bool = field(default=False, metadata={\"help\": \"Enable overlong sequence buffer\"})\n    len: int = field(default=512, metadata={\"help\": \"Buffer length for overlong sequences\"})\n    penalty_factor: float = field(default=1.0, metadata={\"help\": \"Penalty factor for overlong sequences\"})\n    log: bool = field(default=False, metadata={\"help\": \"Enable logging of overlong buffer rewards and penalties\"})\n\n    def to_dict(self) -> Dict[str, Any]:\n        return asdict(self)\n\n\n@dataclass\nclass RewardModelArguments:\n    enable: bool = field(default=False, metadata={\"help\": \"Enable reward model\"})\n    strategy: str = field(default=\"fsdp\", metadata={\"help\": \"Parallel strategy\"})\n    model: ModelArguments = field(\n        default_factory=lambda: ModelArguments(\n            path=\"~/models/deepseek-llm-7b-chat\", enable_gradient_checkpointing=True\n        ),\n        metadata={\"help\": \"Reward model\"},\n    )\n    fsdp_config: FSDPArguments = field(\n        default_factory=lambda: FSDPArguments(wrap_policy={\"min_num_params\": 0}, param_offload=False),\n        metadata={\"help\": \"FSDP configuration\"},\n    )\n    megatron: MegatronArguments = field(default_factory=MegatronArguments, metadata={\"help\": \"Megatron settings\"})\n    micro_batch_size: Optional[int] = field(default=None, metadata={\"help\": \"[Deprecated] Micro-batch size\"})\n    micro_batch_size_per_gpu: Optional[int] = field(default=None, metadata={\"help\": \"Per-GPU micro-batch size\"})\n    max_length: Optional[int] = field(default=None, metadata={\"help\": \"Max sequence length\"})\n    ulysses_sequence_parallel_size: int = field(default=1, metadata={\"help\": \"Sequence parallel size\"})\n    use_dynamic_bsz: bool = field(default=False, metadata={\"help\": \"Dynamic batch size\"})\n    reward_manager: str = field(default=\"naive\", metadata={\"help\": \"Reward management strategy\"})\n    forward_max_token_len_per_gpu: int = field(default=32768, metadata={\"help\": \"Forward max token length per GPU\"})\n    load_weight: bool = field(default=True)\n    launch_reward_fn_async: bool = field(\n        default=False, metadata={\"help\": \"custom reward function executed async on CPU, during log_prob\"}\n    )\n    reward_kwargs: Dict[str, Any] = field(default_factory=lambda: {})\n    sandbox_fusion: Optional[Dict[str, Any]] = field(default=None)\n    overlong_buffer: OverlongBufferArguments = field(\n        default_factory=OverlongBufferArguments, metadata={\"help\": \"DAPO 
overlong buffer configuration\"}\n    )\n    profile: dict[str, Any] = field(default_factory=dict, metadata={\"help\": \"Reward Model Profile settings\"})\n    shuffle: bool = field(default=False, metadata={\"help\": \"Data shuffling\"})\n    data_loader_seed: Optional[int] = field(default=None, metadata={\"help\": \"Data loader seed\"})\n\n    def to_dict(self) -> Dict[str, Any]:\n        return asdict(self)\n\n\n@dataclass\nclass KLCtrlArguments:\n    type: str = field(default=\"fixed\", metadata={\"help\": \"Type of KL Ctrl, fixed or adaptive\"})\n    kl_coef: float = field(default=0.001, metadata={\"help\": \"Coef of KL\"})\n    target_kl: Optional[float] = field(default=0.1, metadata={\"help\": \"Target KL value\"})\n    horizon: Optional[float] = field(default=10000, metadata={\"help\": \"Horizon of KL\"})\n\n\n@dataclass\nclass FilterGroupsArguments:\n    \"\"\"DAPO-specific filter groups configuration for dynamic sampling.\"\"\"\n\n    enable: bool = field(default=False, metadata={\"help\": \"Enable trajectory filtering based on variance\"})\n    metric: str = field(\n        default=\"acc\", metadata={\"help\": \"Metric used for filtering (acc, seq_final_reward, seq_reward)\"}\n    )\n    max_num_gen_batches: int = field(default=10, metadata={\"help\": \"Maximum generation batches before giving up\"})\n\n    def to_dict(self) -> Dict[str, Any]:\n        return asdict(self)\n\n\n@dataclass\nclass AlgorithmArguments:\n    gamma: float = field(default=1.0, metadata={\"help\": \"Discount factor\"})\n    lam: float = field(default=1.0, metadata={\"help\": \"GAE lambda\"})\n    adv_estimator: str = field(default=\"gae\", metadata={\"help\": \"Advantage estimator\"})\n    kl_penalty: str = field(default=\"kl\", metadata={\"help\": \"KL penalty type\"})\n    kl_ctrl: KLCtrlArguments = field(default_factory=KLCtrlArguments)\n    use_kl_in_reward: bool = field(default=False, metadata={\"help\": \"Use KL In-Reward\"})\n    share_reward_in_agent: bool = field(default=True, metadata={\"help\": \"Whether to share the reward across agents\"})\n    norm_adv_by_std_in_grpo: bool = field(default=True, metadata={\"help\": \"Whether to scale the GRPO advantage\"})\n    algorithm_name: str = field(default=\"grpo\", metadata={\"help\": \"Algorithm name, e.g., grpo, ppo, dapo\"})\n    weight_factor_in_cpgd: str = field(\n        default=\"STD_weight\",\n        metadata={\"help\": \"The weighting methods for advantage {STD_weight, clip_filter_like_weight, naive}\"},\n    )\n    workflow_type: str = field(\n        default=\"default\",\n        metadata={\"help\": \"Selects the workflow graph. 'default' for standard PPO/GRPO, 'dapo' for the DAPO workflow.\"},\n    )\n    filter_groups: FilterGroupsArguments = field(\n        default_factory=FilterGroupsArguments, metadata={\"help\": \"DAPO filter groups configuration\"}\n    )\n    use_pf_ppo: bool = field(default=False, metadata={\"help\": \"Whether to enable preference feedback PPO.\"})\n    pf_ppo: dict[str, Any] = field(default_factory=dict, metadata={\"help\": \"Preference feedback PPO settings.\"})\n    embodied_sampling: EmbodiedSamplingConfig = field(\n        default_factory=EmbodiedSamplingConfig, metadata={\"help\": \"Embodied dynamic sampling configuration\"}\n    )\n\n    def to_dict(self) -> Dict[str, Any]:\n        return asdict(self)\n"
  },
  {
    "path": "siirl/params/parser.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport os\nimport hydra\nfrom pathlib import Path\nfrom typing import Any\nimport transformers\nfrom omegaconf import OmegaConf, DictConfig\n\nfrom siirl.params.training_args import SiiRLArguments\n\n\ndef _set_transformers_logging() -> None:\n    if os.getenv(\"SIIRL_LOG_VERBOSITY\", \"INFO\") in [\"DEBUG\", \"INFO\"]:\n        transformers.utils.logging.set_verbosity_info()\n        transformers.utils.logging.enable_default_handler()\n        transformers.utils.logging.enable_explicit_format()\n\n\ndef parse_config() -> SiiRLArguments:\n    \"\"\"Parse configuration using OmegaConf and convert to a SiiRLArguments instance.\"\"\"\n    parser = argparse.ArgumentParser()\n    _, overrides = parser.parse_known_args()\n    overrides = OmegaConf.from_cli(overrides)\n    # Convert OmegaConf config to a dictionary\n    siirl_config_dict = OmegaConf.to_container(overrides, resolve=True)\n\n    # Recursively convert nested configs\n    def convert_to_dataclass(obj: Any, dataclass_type: Any):\n        if isinstance(obj, dict):\n            fields = dataclass_type.__dataclass_fields__\n            kwargs = {}\n            for name, field_type in fields.items():\n                if name in obj:\n                    # Handle nested dataclasses\n                    if hasattr(field_type.type, \"__dataclass_fields__\"):\n                        kwargs[name] = convert_to_dataclass(obj[name], field_type.type)\n                    else:\n                        kwargs[name] = obj[name]\n            return dataclass_type(**kwargs)\n        return obj\n\n    # Convert root config\n    siirl_args = convert_to_dataclass(siirl_config_dict, SiiRLArguments)\n    _set_transformers_logging()\n    return siirl_args\n"
  },
  {
    "path": "siirl/params/profiler_args.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom dataclasses import asdict, dataclass, field\nfrom typing import Any, Dict\n\n\n@dataclass\nclass ProfilerArguments:\n    enable: bool = field(default=False, metadata={\"help\": \"Whether to enable profiling\"})\n    save_path: str = field(default=\"./prof_data\", metadata={\"help\": \"Storage path for collected data\"})\n    level: str = field(default=\"level1\", metadata={\"help\": \"Collection level-options are level_none, level0, level1, and level2\"})\n    with_memory: bool = field(default=False, metadata={\"help\": \"Whether to enable memory analysis\"})\n    record_shapes: bool = field(default=False, metadata={\"help\": \"Whether to record tensor shapes\"})\n    with_npu: bool = field(default=True, metadata={\"help\": \"Whether to collect device-side performance data\"})\n    with_cpu: bool = field(default=False, metadata={\"help\": \"Whether to collect host-side performance data\"})\n    with_module: bool = field(default=False, metadata={\"help\": \"Whether to record framework-layer Python call stack information\"})\n    with_stack: bool = field(default=False, metadata={\"help\": \"Whether to record operator call stack information\"})\n    analysis: bool = field(default=True, metadata={\"help\": \"Enables automatic data parsing\"})\n    discrete: bool = field(default=False, metadata={\n        \"help\": \"True for each task has its own database, False for all tasks in one training step share one database\"})\n    roles: list[str] = field(default_factory=lambda: [\"generate\",\"compute_reward\"], metadata={\n        \"help\": \"Used for discrete mode, optional values: generate, compute_reward, compute_old_log_prob, compute_ref_log_porb, compute_value, compute_advantage, train_critic, train_actor\"})\n    all_ranks: bool = field(default=False, metadata={\"help\": \"Whether to profile all ranks\"})\n    ranks: list[int] = field(default_factory=lambda: [0], metadata={\"help\": \"The ranks that will be profiled. [] or [0,1,...]\"})\n    profile_steps: list[int] = field(default_factory=lambda: [0], metadata={\"help\": \"The steps that will be profiled\"})\n    \n    def to_dict(self) -> Dict[str, Any]:\n        return asdict(self)\n"
  },
  {
    "path": "siirl/params/training_args.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom dataclasses import asdict, dataclass, field\nfrom typing import Optional, Dict, List, Any\nfrom siirl.params.data_args import DataArguments\nfrom siirl.params.model_args import (\n    ActorRolloutRefArguments,\n    CriticArguments,\n    RewardModelArguments,\n    AlgorithmArguments,\n)\nfrom siirl.params.dag_args import DagArguments\nfrom siirl.params.profiler_args import ProfilerArguments\n\n\n@dataclass\nclass TrainingArguments:\n    total_epochs: int = field(default=30, metadata={\"help\": \"Total training epochs\"})\n    total_training_steps: Optional[int] = field(default=None, metadata={\"help\": \"Override training steps\"})\n    project_name: str = field(default=\"siirl_examples\", metadata={\"help\": \"Project name\"})\n    experiment_name: str = field(default=\"gsm8k\", metadata={\"help\": \"Experiment name\"})\n    logger: List[str] = field(\n        default_factory=lambda: [\"console\", \"wandb\"],\n        metadata={\"help\": \"Logging backends\"},\n    )\n    log_val_generations: int = field(default=0, metadata={\"help\": \"Validation samples to log\"})\n    nnodes: int = field(default=1, metadata={\"help\": \"Number of nodes\"})\n    n_gpus_per_node: int = field(default=8, metadata={\"help\": \"GPUs per node\"})\n    save_freq: int = field(default=-1, metadata={\"help\": \"Checkpoint frequency\"})\n    resume_mode: str = field(default=\"auto\", metadata={\"help\": \"Resume training mode\"})\n    resume_from_path: bool = field(default=False, metadata={\"help\": \"Resume from specific path\"})\n    test_freq: int = field(default=-1, metadata={\"help\": \"Testing frequency\"})\n    critic_warmup: int = field(default=0, metadata={\"help\": \"Critic warmup steps\"})\n    default_local_dir: str = field(\n        default=\"checkpoints/siirl_examples/gsm8k\",\n        metadata={\"help\": \"Checkpoint directory\"},\n    )\n    seed: int = field(default=1, metadata={\"help\": \"Train seed param\"})\n    should_log: bool = field(default=False, metadata={\"help\": \"Should print debug log for training\"})\n    should_save: bool = field(\n        default=False,\n        metadata={\"help\": \"Should save tokenized dataset to local disk and exit\"},\n    )\n    val_before_train: bool = field(default=True, metadata={\"help\": \"Whether or not to validate before train\"})\n    default_hdfs_dir: str = field(default=None, metadata={\"help\": \"Default hdfs dir path for checkpoints\"})\n    del_local_ckpt_after_load: bool = field(\n        default=False,\n        metadata={\"help\": \"Whether or not to delete local checkpoints after load\"},\n    )\n    val_only: bool = field(default=False, metadata={\"help\": \"Whether or not just eval only\"})\n    balance_batch: bool = field(\n        default=False,\n        metadata={\"help\": \"Whether or not to balance the number of valid tokens on each dp rank.\"},\n    )\n    remove_previous_ckpt_in_save: bool 
= field(\n        default=False,\n        metadata={\"help\": \"Whether or not to remove previous ckpt in save path.\"},\n    )\n    max_actor_ckpt_to_keep: int = field(default=100, metadata={\"help\": \"Maximum number of actor ckpts.\"})\n    max_critic_ckpt_to_keep: int = field(default=100, metadata={\"help\": \"Maximum number of critic ckpts.\"})\n    ray_wait_register_center_timeout: int = field(default=300, metadata={\"help\": \"The timeout for ray worker group to wait for the register center to be ready\"})\n    validation_data_dir: Optional[str] = field(default=None, metadata={\"help\": \"Validation data directory.\"})\n    rollout_data_dir: Optional[str] = field(default=None, metadata={\"help\": \"Rollout data directory.\"})\n    device: Optional[str] = field(default=\"cuda\", metadata={\"help\": \"Training device.\"})\n\n    def to_dict(self) -> Dict[str, Any]:\n        return asdict(self)\n\n\n@dataclass\nclass CustomRewardArguments:\n    path: str = field(default=None, metadata={\"help\": \"Custom reward function import file path\"})\n    name: str = field(default=\"compute_score\", metadata={\"help\": \"Custom reward function name\"})\n    reward_kwargs: Dict[str, Any] = field(default_factory=lambda: {})\n\n\n@dataclass\nclass SiiRLArguments:\n    data: DataArguments = field(default_factory=DataArguments)\n    actor_rollout_ref: ActorRolloutRefArguments = field(default_factory=ActorRolloutRefArguments)\n    critic: CriticArguments = field(default_factory=CriticArguments)\n    reward_model: RewardModelArguments = field(default_factory=RewardModelArguments)\n    algorithm: AlgorithmArguments = field(default_factory=AlgorithmArguments)\n    trainer: TrainingArguments = field(default_factory=TrainingArguments)\n    custom_reward_function: CustomRewardArguments = field(default_factory=CustomRewardArguments)\n    dag: DagArguments = field(default_factory=DagArguments)\n    profiler: ProfilerArguments = field(default_factory=ProfilerArguments)\n\n    def to_dict(self) -> Dict[str, Any]:\n        return asdict(self)\n"
  },
  {
    "path": "siirl/third_party/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "siirl/third_party/sglang/__init__.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "siirl/third_party/sglang/parallel_state.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2023 The SGlang team.\n# Adapted from\n# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py\n# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.\n\"\"\"Model and data parallel groups.\"\"\"\n\nimport os\nfrom typing import Optional\n\nimport sglang.srt.distributed.parallel_state as ps\nimport torch\nimport torch.distributed\nfrom sglang.srt.distributed.parallel_state import (\n    get_pp_group,\n    get_world_group,\n    init_distributed_environment,\n    init_model_parallel_group,\n)\n\n\"\"\"\nThis version is strongly tied with Megatron to implement HybridEngine and weight sharing between vllm and Megatron.\n- We assume the Megatron tp+dp+pp world is already established before calling this function.\n\n\"\"\"\n\n# Device mesh for using DTensor\n_DEVICE_MESH = None\n\n# Tensor model parallel group that the current rank belongs to.\n_TP = None\n# Pipeline model parallel group that the current rank belongs to.\n_PP = None\n\n\n# This method is for initializing the ParallelGroup when using HybridEngine\n# NOTE(linjunrong): this function is for megatron\ndef initialize_parallel_state(\n    distributed_init_method: str = \"env://\",\n    backend: str = \"nccl\",\n    tensor_model_parallel_size: int = 1,\n    num_tp_per_train_tp: int = 1,\n    pipeline_model_parallel_size: int = 1,\n):\n    # torch.distributed.all_reduce does not free the input tensor until\n    # the synchronization point. This causes the memory usage to grow\n    # as the number of all_reduce calls increases. This env var disables\n    # this behavior.\n    # Related issue:\n    # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573\n    os.environ[\"TORCH_NCCL_AVOID_RECORD_STREAMS\"] = \"1\"\n\n    # NOTE(sgm): Modify for siirl, Env vars will be set by TORCHRUN.\n    rank = int(os.getenv(\"RANK\", \"-1\"))\n    local_rank = int(os.getenv(\"LOCAL_RANK\", \"0\"))\n\n    # Use the world_size set by TORCHRUN\n    world_size = int(os.getenv(\"WORLD_SIZE\", \"-1\"))\n    assert world_size != -1, \"The world_size is set to -1, not initialized by TORCHRUN\"\n    init_distributed_environment(world_size, rank, distributed_init_method, local_rank, backend)\n    if torch.distributed.get_world_size() > 1:\n        # NOTE: build a separate inference group with infer tp & micro dp\n        initialize_model_parallel_for_sglang(\n            tensor_model_parallel_size=tensor_model_parallel_size,\n            num_tensor_model_parallel_groups_per_train_tp=num_tp_per_train_tp,\n        )\n    else:\n        initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend)\n\n\n# NOTE(linjunrong): After init SGLang rollout using class EngineFragment, user should always remember to call\n# this function to sync the _TP, _PP define at the beginning of this file. Otherwise, only the counterparts\n# inside sglang.srt.distributed are init as ProcessGroup, the symbols defined in this file remain as None.\n# It could be weird to maintain two _TP and _PP, I follow the same way to maintain an extra ones for\n# siirl itself as how it was done in siirl.third_party.vllm.parallel_state. 
Note that the process is a little\n# bit different\ndef ensure_model_parallel_initialized(\n    tensor_model_parallel_size: int,\n    pipeline_model_parallel_size: int = 1,\n    backend: Optional[str] = None,\n) -> None:\n    \"\"\"Helper to initialize model parallel groups if they are not initialized,\n    or ensure tensor-parallel and pipeline-parallel sizes are equal to expected\n    values if the model parallel groups are initialized.\n    \"\"\"\n    # get the backend of _DEVICE_WORLD_GROUP\n    backend = backend or torch.distributed.get_backend(get_world_group().device_group)\n    if not model_parallel_is_initialized():\n        initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend)\n        return\n\n    assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, (\n        f\"tensor parallel group already initialized, but of unexpected size: \"\n        f\"{get_tensor_model_parallel_world_size()=} vs. {tensor_model_parallel_size=}\"\n    )\n    pp_world_size = get_pp_group().world_size\n    assert pp_world_size == pipeline_model_parallel_size, (\n        f\"pipeline parallel group already initialized, but of unexpected size: {pp_world_size=} vs. \"\n        f\"{pipeline_model_parallel_size=}\"\n    )\n\n\n# TODO(sgm): deviate from the v0.5.4, not pp now\n# NOTE(linjunrong): the SGLang version using _TP instead of ps._TP\ndef model_parallel_is_initialized():\n    \"\"\"Check if tensor and pipeline parallel groups are initialized.\"\"\"\n    return _TP is not None\n    # and _PIPELINE_MODEL_PARALLEL_GROUP is not None)\n\n\ndef initialize_model_parallel_for_sglang(\n    tensor_model_parallel_size: int,\n    num_tensor_model_parallel_groups_per_train_tp: int = 1,\n    pipeline_model_parallel_size: int = 1,\n) -> None:\n    pass\n\n    # Get world size and rank. 
Ensure some consistencies.\n    assert torch.distributed.is_initialized()\n\n    assert isinstance(tensor_model_parallel_size, int)\n\n    # assert num_tensor_model_parallel_groups_per_train_tp == 1 and not different_tp_group\n    # assert num_tensor_model_parallel_groups_per_train_tp > 1 and different_tp_group\n\n    # Build the tensor model-parallel groups.\n    assert ps._TP is None, \"tensor model parallel group is already initialized\"\n\n    global _TP\n\n    world_size: int = torch.distributed.get_world_size()\n\n    backend = torch.distributed.get_backend()\n\n    num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size\n\n    if num_tensor_model_parallel_groups_per_train_tp == 1:\n        # if tensor_model_parallel_size == train_tensor_parallel_size:\n        # using the same tp group as Megatron/vllm\n        assert _TP is None, \"tensor model parallel group is already initialized\"\n        group_ranks = []\n        for i in range(num_tensor_model_parallel_groups):\n            ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)\n            group_ranks.append(ranks)\n        _TP = init_model_parallel_group(\n            group_ranks=group_ranks,\n            local_rank=get_world_group().local_rank,\n            backend=backend,\n            use_custom_allreduce=False,  # TODO: check why True is not work in Ray trainer\n            use_message_queue_broadcaster=True,\n        )\n        ps._TP = _TP\n        # _MICRO_DATA_PARALLEL_GROUP is move to hybrid engine\n    else:\n        # initialize a micro_dp group and a tp group\n        # assume training tp=4, infer tp=2, then, weight is partitioned as\n        # [1], [2], [3], [4] for training and [1,2], [1,2], [3,4], [3,4] for inference\n\n        # Build the inference tp groups\n        # train_tp = train_tensor_parallel_size\n        train_tp = num_tensor_model_parallel_groups_per_train_tp * tensor_model_parallel_size\n        # num_tensor_model_parallel_groups_per_train_tp = train_tp // tensor_model_parallel_size\n        assert _TP is None, \"tensor model parallel group is already initialized\"\n        group_ranks = []\n        for i in range(num_tensor_model_parallel_groups // num_tensor_model_parallel_groups_per_train_tp):\n            start = train_tp * i\n            end = train_tp * (i + 1)\n            for j in range(num_tensor_model_parallel_groups_per_train_tp):\n                ranks = list(range(start, end, num_tensor_model_parallel_groups_per_train_tp))\n                for i in range(len(ranks)):\n                    ranks[i] += j\n                group_ranks.append(ranks)\n        _TP = init_model_parallel_group(\n            group_ranks=group_ranks,\n            local_rank=get_world_group().local_rank,\n            backend=backend,\n            use_custom_allreduce=False,  # TODO: check why True is not work in Ray trainer\n            use_message_queue_broadcaster=True,\n        )\n        ps._TP = _TP\n\n    # Build the pipeline model-parallel groups.\n    # global _PIPELINE_MODEL_PARALLEL_GROUP\n    # global _PIPELINE_GLOBAL_RANKS\n    # assert ps._PIPELINE_MODEL_PARALLEL_GROUP is None, (\"pipeline model parallel group is already initialized\")\n\n    # ps._PIPELINE_MODEL_PARALLEL_GROUP = mpu.get_pipeline_model_parallel_group()\n    # ps._PIPELINE_GLOBAL_RANKS = mpu.get_pipeline_model_parallel_ranks()\n\n    # TODO: init using device mesh (not support hybrid engine now)\n    # Build the pipeline model-parallel groups.\n    num_pipeline_model_parallel_groups: 
int = world_size // pipeline_model_parallel_size\n    global _PP\n    assert _PP is None, \"pipeline model parallel group is already initialized\"\n    group_ranks = []\n    for i in range(num_pipeline_model_parallel_groups):\n        ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))\n        group_ranks.append(ranks)\n    # pipeline parallel does not need custom allreduce\n    _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False)\n    ps._PP = _PP  # for siirl\n\n\ndef initialize_model_parallel(\n    tensor_model_parallel_size: int = 1,\n    pipeline_model_parallel_size: int = 1,\n    backend: Optional[str] = None,\n) -> None:\n    \"\"\"\n    NOTE: This method is a hack from the open-sourced version without\n    assertion of world_size = tp * pp\n\n    Initialize model parallel groups.\n\n    Arguments:\n        tensor_model_parallel_size: number of GPUs used for tensor model\n            parallelism.\n        pipeline_model_parallel_size: number of GPUs used for pipeline model\n            parallelism.\n\n    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we\n    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize\n    the model pipeline. The present function will\n    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:\n        4 tensor model-parallel groups:\n            [g0, g1], [g2, g3], [g4, g5], [g6, g7]\n        2 pipeline model-parallel groups:\n            [g0, g2, g4, g6], [g1, g3, g5, g7]\n    Note that for efficiency, the caller should make sure adjacent ranks\n    are on the same DGX box. For example if we are using 2 DGX-1 boxes\n    with a total of 16 GPUs, rank 0 to 7 belong to the first box and\n    ranks 8 to 15 belong to the second box.\n    \"\"\"\n    # Get world size and rank. 
Ensure some consistencies.\n    assert torch.distributed.is_initialized()\n    world_size: int = torch.distributed.get_world_size()\n    backend = backend or torch.distributed.get_backend(ps.get_world_group().device_group)\n\n    # NOTE(sgm) we don't assert world_size == tp * pp\n    # DP is not managed by vllm but by the siiRL WorkerGroup\n    # if (world_size !=\n    #         tensor_model_parallel_size * pipeline_model_parallel_size):\n    #     raise RuntimeError(\n    #         f\"world_size ({world_size}) is not equal to \"\n    #         f\"tensor_model_parallel_size ({tensor_model_parallel_size}) x \"\n    #         f\"pipeline_model_parallel_size ({pipeline_model_parallel_size})\")\n\n    num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size\n\n    global _TP\n    assert _TP is None, \"tensor model parallel group is already initialized\"\n    group_ranks = []\n    for i in range(num_tensor_model_parallel_groups):\n        ranks = list(range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size))\n        group_ranks.append(ranks)\n\n    # message queue broadcaster is only used in tensor model parallel group\n    if ps._TP is not None:\n        _TP = ps._TP\n    else:\n        _TP = init_model_parallel_group(\n            group_ranks,\n            get_world_group().local_rank,\n            backend,\n            use_custom_allreduce=False,  # TODO: check why True is not work in Ray trainer\n            use_message_queue_broadcaster=True,\n        )\n        ps._TP = _TP\n\n    # TODO: init using device mesh (not support hybrid engine now)\n    # Build the pipeline model-parallel groups.\n    num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size\n    global _PP\n    assert _PP is None, \"pipeline model parallel group is already initialized\"\n    group_ranks = []\n    for i in range(num_pipeline_model_parallel_groups):\n        ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))\n        group_ranks.append(ranks)\n    # pipeline parallel does not need custom allreduce\n    if ps._TP is not None:\n        _PP = ps._TP\n    else:\n        _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False)\n        ps._PP = _PP\n\n\n\"\"\"\nDevice mesh utilities\n\"\"\"\n\n\ndef get_device_mesh():\n    assert _DEVICE_MESH is not None, \"device mesh is not initialized\"\n    return _DEVICE_MESH\n\n\n\"\"\"\nTensor model parallel utilities\n\"\"\"\n\n\n# NOTE(linjunrong): In the vllm version parallel_state.py. siirl created its own _TP and _PP as siirl want to use\n# the process group for some extra purpose. Under the hood, there is no difference between them and the original\n# one in vllm.distributed.parallel_state. 
However, the implementation need to hack the init process of inference\n# engine, as we do not maintain another SGLang here, I just use the original _TP and _PP directly.\ndef get_tensor_model_parallel_group():\n    \"\"\"Get the tensor model parallel group the caller rank belongs to.\"\"\"\n\n    assert _TP is not None, \"tensor model parallel group is not initialized\"\n    return _TP.device_group\n\n\ndef get_tensor_model_parallel_world_size():\n    \"\"\"Return world size for the tensor model parallel group.\"\"\"\n    return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())\n\n\ndef get_tensor_model_parallel_rank():\n    \"\"\"Return my rank for the tensor model parallel group.\"\"\"\n    return torch.distributed.get_rank(group=get_tensor_model_parallel_group())\n\n\ndef get_tensor_model_parallel_src_rank():\n    \"\"\"Calculate the global rank corresponding to the first local rank\n    in the tensor model parallel group.\"\"\"\n    global_rank = torch.distributed.get_rank()\n    local_world_size = get_tensor_model_parallel_world_size()\n    return (global_rank // local_world_size) * local_world_size\n"
  },
  {
    "path": "siirl/user_interface/filter_interface/__init__.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .dapo import dynamic_sampling\n\n__all__ = [\"dynamic_sampling\"]\n"
  },
  {
    "path": "siirl/user_interface/filter_interface/dapo.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom collections import defaultdict\nfrom typing import Any, Dict\n\nimport numpy as np\nfrom tensordict import TensorDict\nfrom siirl.params import SiiRLArguments\nfrom siirl.dag_worker.data_structures import NodeOutput\nfrom siirl.data_coordinator.sample import filter_tensordict\n\ndef dynamic_sampling(config: SiiRLArguments, batch: TensorDict, **kwargs: Any) -> NodeOutput:\n    \"\"\"\n    Performs dynamic sampling by filtering trajectory groups based on metric variance.\n\n    Args:\n        config (SiiRLArguments): The global training arguments from the configuration.\n        batch (TensorDict): The input data batch for this step, which must contain 'uid'\n                           in `non_tensor_batch` and the specified metric for filtering.\n        node_config (Dict[str, Any]): The configuration specific to this node (not used here).\n        **kwargs (Any): Additional keyword arguments (not used here).\n\n    Returns:\n        NodeOutput: An output object containing the filtered batch and metrics about the filtering process.\n\n    Raises:\n        KeyError: If the specified metric for filtering cannot be found in the batch and\n                  cannot be computed from available data.\n    \"\"\"\n    filter_config = config.algorithm.filter_groups\n\n    # If filtering is disabled in the main config, bypass the logic and return the original batch.\n    if not filter_config.enable:\n        return NodeOutput(batch=batch, metrics={\"sampling/kept_trajectories_ratio\": 1.0})\n\n    metric_name = filter_config.metric\n    initial_traj_count = len(batch) if batch is not None else 0\n\n    # Ensure the filtering metric exists. If not, try to compute it on-the-fly,\n    # mirroring the behavior of dapo_ray_trainer.py.\n    if metric_name not in batch:\n        if metric_name == \"seq_final_reward\" and \"token_level_rewards\" in batch.batch:\n            # Calculate from token-level rewards if necessary.\n            batch[\"seq_final_reward\"] = batch[\"token_level_rewards\"].sum(dim=-1).cpu().numpy()\n        elif metric_name == \"seq_reward\" and \"token_level_scores\" in batch:\n            # Calculate from token-level scores if necessary.\n            batch[\"seq_reward\"] = batch[\"token_level_scores\"].sum(dim=-1).cpu().numpy()\n        else:\n            # If the metric cannot be found or computed, it's a configuration error.\n            raise KeyError(f\"Metric '{metric_name}' for group filtering not found in batch and could not be computed. 
Available non-tensor keys: {list(batch.keys())}\")\n\n    # Group trajectories by UID and collect their corresponding metric values.\n    prompt_uid_to_metric_vals = defaultdict(list)\n    uids = batch[\"uid\"]\n    metric_values = batch[metric_name]\n\n    def _to_uid_key(uid):\n        \"\"\"Convert uid to a hashable key (handles tensor, np.str_, and regular values).\"\"\"\n        if hasattr(uid, 'item'):\n            return uid.item()  # tensor -> Python scalar\n        elif hasattr(uid, 'tolist'):\n            return uid.tolist()  # numpy array/scalar -> Python value\n        else:\n            return str(uid) if not isinstance(uid, (int, str)) else uid\n\n    for i in range(len(uids)):\n        uid_key = _to_uid_key(uids[i])\n        prompt_uid_to_metric_vals[uid_key].append(metric_values[i])\n\n    # Calculate the standard deviation of the metric for each group of trajectories.\n    prompt_uid_to_metric_std = {\n        prompt_uid: np.std(metric_vals) for prompt_uid, metric_vals in prompt_uid_to_metric_vals.items()\n    }\n\n    # Decide which prompts (UIDs) to keep. A group is kept if its metric values\n    # show variance (std > 0) or if it's a single-sample group (which cannot have variance).\n    kept_prompt_uids = {\n        uid for uid, std in prompt_uid_to_metric_std.items() if std > 0 or len(prompt_uid_to_metric_vals[uid]) == 1\n    }\n\n    # Find the indices of all trajectories that belong to the kept groups.\n    if not kept_prompt_uids:\n        kept_traj_indices = []\n    else:\n        kept_traj_indices = [\n            idx for idx in range(len(uids))\n            if _to_uid_key(uids[idx]) in kept_prompt_uids\n        ]\n\n    # Filter the original batch by slicing it with the collected indices.\n    # The TensorDict object natively supports this slicing operation.\n    filtered_batch = filter_tensordict(batch, kept_traj_indices)\n\n    # Calculate and return metrics about the filtering process for logging and analysis.\n    final_traj_count = len(filtered_batch) if filtered_batch is not None else 0\n    kept_ratio = final_traj_count / initial_traj_count if initial_traj_count > 0 else 1.0\n    metrics = {\"dapo_sampling/kept_trajectories_ratio\": kept_ratio, \"dapo_sampling/initial_trajectories\": initial_traj_count, \"dapo_sampling/final_trajectories\": final_traj_count, \"dapo_sampling/kept_groups\": len(kept_prompt_uids), \"dapo_sampling/total_groups\": len(prompt_uid_to_metric_vals)}\n    \n    # Also return the indices for np.ndarray filtering in the next node\n    metrics['dapo_sampling/filtered_indices'] = kept_traj_indices\n\n    return NodeOutput(batch=filtered_batch, metrics=metrics)\n"
  },
  {
    "path": "siirl/user_interface/filter_interface/embodied.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nfrom collections import Counter\nfrom typing import Any, Dict, List, Tuple\n\nimport torch\nfrom loguru import logger\nfrom tensordict import TensorDict\n\nfrom siirl.params import SiiRLArguments\nfrom siirl.dag_worker.data_structures import NodeOutput\nfrom siirl.data_coordinator import SampleInfo\nfrom siirl.data_coordinator.protocol import select_idxs\n\ndef verify(\n    data: TensorDict,\n) -> Tuple[List[float], Dict[str, float], Dict[str, float], Dict[str, float]]:\n    \"\"\"Calculates scores and enriches the batch with accuracy information.\n\n    This function uses the 'complete' field from the batch as the ground truth\n    for scores. It then writes the calculated scores ('acc') and a format\n    correctness tensor back into the input `data` object for downstream use.\n\n    Args:\n        data: The TensorDict object containing batch data, including 'responses'\n            and 'complete' tensors.\n\n    Returns:\n        A tuple containing:\n        - scores_list: A Python list of float scores for each sample.\n        - reward_metrics: A dictionary of aggregate reward metrics (e.g., mean score).\n        - format_metrics: A dictionary of aggregate format metrics.\n        - reward_format_metrics: A dictionary of reward metrics excluding format issues.\n    \"\"\"\n    # --- 1. Access Tensors and Metadata ---\n    responses = data[\"responses\"]\n    completes = data[\"complete\"]\n    device = responses.device\n    batch_size = responses.size(0)\n\n    # --- 2. Sanity Check ---\n    assert completes.size(0) == batch_size, \"Batch size mismatch between 'completes' and 'responses'.\"\n\n    # --- 3. Create Score Tensors ---\n    scores_tensor = completes.float()\n    # Assume format is always correct for this verification step.\n    format_tensor = torch.ones(batch_size, dtype=torch.float32, device=device)\n\n    data[\"acc\"] = scores_tensor\n    data[\"format_correctness\"] = format_tensor\n\n    # --- 4. Calculate Aggregate Metrics ---\n    mean_score = scores_tensor.mean().item()\n    reward_metrics = {\"all\": mean_score}\n    format_metrics = {\"all\": 1.0}  # Always 1.0 based on the assumption above\n    reward_format_metrics = {\"all\": mean_score}\n\n    return scores_tensor.tolist(), reward_metrics, format_metrics, reward_format_metrics\n\n\ndef _filter_batch(batch: TensorDict, n_samples: int, config: SiiRLArguments) -> TensorDict:\n    \"\"\"Filters a batch based on accuracy and truncation criteria.\n\n    Filtering is performed at the prompt level. If any of the `n_samples`\n    responses for a single prompt fails a check, all `n_samples` responses\n    for that prompt are discarded.\n\n    Args:\n        batch: The TensorDict object to be filtered. 
Must contain 'acc' tensor.\n        n_samples: The number of responses generated per prompt.\n        config: Configuration object containing filter settings.\n\n    Returns:\n        A new, potentially smaller, TensorDict object containing the filtered data.\n    \"\"\"\n    device = batch[\"responses\"].device\n    num_prompts = len(batch) // n_samples\n    rank = int(os.environ.get(\"RANK\", \"0\"))\n\n    embodied_sampling = config.algorithm.embodied_sampling\n    filter_accuracy = embodied_sampling.filter_accuracy\n\n    if filter_accuracy:\n        acc_matrix = batch[\"acc\"].reshape(num_prompts, n_samples)\n        prompt_mean_acc = acc_matrix.mean(dim=-1)\n\n        if config.dag.enable_perf:\n            counts = Counter(prompt_mean_acc.tolist())\n            log_lines = [f\"Accuracy Distribution ({len(prompt_mean_acc)} prompts):\"]\n            for score, count in sorted(counts.items()):\n                log_lines.append(f\"  - Score {score:.2f}: {count} prompts\")\n            logger.info(\"\\n\".join(log_lines))\n\n        accuracy_lower_bound = embodied_sampling.accuracy_lower_bound\n        accuracy_upper_bound = embodied_sampling.accuracy_upper_bound\n        acc_mask = (prompt_mean_acc >= accuracy_lower_bound) & (prompt_mean_acc <= accuracy_upper_bound)\n    else:\n        acc_mask = torch.ones(num_prompts, dtype=torch.bool, device=device)\n\n    filter_truncated = embodied_sampling.filter_truncated\n    if filter_truncated:\n        if \"finish_step\" in batch:\n            finish_steps = batch[\"finish_step\"].reshape(num_prompts, n_samples)\n            max_steps = config.actor_rollout_ref.embodied.env.max_steps\n            has_truncated = (finish_steps >= max_steps).any(dim=-1)\n\n            if rank == 0:\n                truncated = int(has_truncated.sum().item())\n                kept = len(has_truncated) - truncated\n                logger.info(f\"Truncation: {truncated} truncated, {kept} kept (out of {len(has_truncated)} prompts)\")\n\n            trunc_mask = ~has_truncated\n        else:\n            logger.warning(\"No 'finish_step' field found in batch. 
Skipping truncation filtering.\")\n            trunc_mask = torch.ones(num_prompts, dtype=torch.bool, device=device)\n    else:\n        trunc_mask = torch.ones(num_prompts, dtype=torch.bool, device=device)\n\n    combined_mask = acc_mask & trunc_mask\n\n    if rank == 0:\n        kept = combined_mask.sum().item()\n        filtered = num_prompts - kept\n        logger.info(f\"Filter: {num_prompts} prompts → {kept} kept, {filtered} filtered\")\n\n    final_mask = combined_mask.repeat_interleave(n_samples)\n    filtered_batch = select_idxs(batch, final_mask)\n\n    return filtered_batch\n\n\ndef _compute_embodied_verification_metrics(\n    batch: TensorDict,\n    config: SiiRLArguments,\n) -> Dict[str, float]:\n    \"\"\"\n    Compute Embodied AI-specific metrics during verification phase.\n    \n    Args:\n        batch: The batch being verified\n        config: Configuration arguments\n    \n    Returns:\n        Dictionary of Embodied verification metrics\n    \"\"\"\n    try:\n        from siirl.utils.embodied.metrics import (\n            compute_rollout_metrics,\n        )\n        \n        metrics = {}\n        \n        # Prepare batch dict for metrics computation\n        batch_dict = {\n            'responses': batch.get('responses'),\n            'complete': batch.get('complete'),\n            'finish_step': batch.get('finish_step'),\n        }\n        \n        # Add optional fields\n        if 'pixel_values' in batch:\n            batch_dict['pixel_values'] = batch['pixel_values']\n        if 'acc' in batch:\n            batch_dict['acc'] = batch['acc']\n        \n        # Compute rollout metrics\n        rollout_metrics = compute_rollout_metrics(batch_dict, config)\n        for key, value in rollout_metrics.items():\n            # Add verify_ prefix to distinguish from actor metrics\n            metrics[f\"verify_{key}\"] = value\n        \n        return metrics\n        \n    except Exception as e:\n        logger.debug(f\"Failed to compute Embodied verification metrics: {e}\")\n        return {}\n\n\ndef embodied_local_rank_sampling(\n    config: SiiRLArguments,\n    batch: TensorDict,\n    **kwargs: Any,\n) -> NodeOutput:\n    \"\"\"Performs verification, metric collection, and optional filtering on a batch.\n\n    This function orchestrates the post-generation processing pipeline for a batch\n    of samples. 
It first verifies all samples, then filters them according to\n    configuration, and finally attaches the calculated metrics to the resulting batch.\n\n    Args:\n        config: Global SiiRL configuration arguments.\n        batch: The input TensorDict batch from the generation stage.\n        node_config: Configuration specific to this execution node.\n        **kwargs: Additional keyword arguments (unused).\n\n    Returns:\n        A NodeOutput object containing the processed (and potentially filtered) batch.\n    \"\"\"\n    import os\n\n    original_batch_size = batch.batch_size[0] if hasattr(batch, 'batch_size') else len(batch)\n    rank = int(os.environ.get(\"RANK\", \"0\"))\n\n    _, reward_metrics, format_metrics, reward_format_metrics = verify(batch)\n\n    sample_metrics = {}\n\n    enable_embodied_metrics = True\n    if hasattr(config, 'actor_rollout_ref') and hasattr(config.actor_rollout_ref, 'embodied'):\n        if config.actor_rollout_ref.embodied is not None:\n            if hasattr(config.actor_rollout_ref.embodied, 'enable_vla_metrics'):\n                enable_embodied_metrics = config.actor_rollout_ref.embodied.enable_vla_metrics\n\n    if enable_embodied_metrics:\n        embodied_verification_metrics = _compute_embodied_verification_metrics(batch, config)\n        sample_metrics.update(embodied_verification_metrics)\n\n    embodied_sampling = config.algorithm.embodied_sampling\n    if embodied_sampling.filter_accuracy or embodied_sampling.filter_truncated:\n        n_samples = config.actor_rollout_ref.rollout.n\n        processed_batch = _filter_batch(batch, n_samples, config)\n\n        if rank == 0:\n            filtered_size = processed_batch.batch_size[0] if hasattr(processed_batch, 'batch_size') else len(processed_batch)\n            filtered_count = original_batch_size - filtered_size\n            success_rate = reward_metrics.get('all', 0.0)\n            logger.info(f\"[SAMPLING] {original_batch_size} samples → filtered {filtered_count} → {filtered_size} remaining | Success: {success_rate:.1%}\")\n    else:\n        processed_batch = batch\n        if rank == 0:\n            success_rate = reward_metrics.get('all', 0.0)\n            logger.info(f\"[SAMPLING] {original_batch_size} samples | Success: {success_rate:.1%} (no filter)\")\n\n    if processed_batch is not None:\n        for key, tensor in processed_batch.items():\n            if isinstance(tensor, torch.Tensor) and tensor.device.type != 'cpu':\n                processed_batch[key] = tensor.cpu()\n\n    return NodeOutput(batch=processed_batch, metrics=sample_metrics)"
  },
  {
    "path": "siirl/user_interface/rewards_interface/custom_gsm8k_reward.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport re\n\n\ndef extract_solution(solution_str, method=\"strict\"):\n    assert method in [\"strict\", \"flexible\"]\n\n    if method == \"strict\":\n        # this also tests the formatting of the model\n        solution = re.search(\"#### (\\\\-?[0-9\\\\.\\\\,]+)\", solution_str)\n        if solution is None:\n            final_answer = None\n        else:\n            final_answer = solution.group(0)\n            final_answer = final_answer.split(\"#### \")[1].replace(\",\", \"\").replace(\"$\", \"\")\n    elif method == \"flexible\":\n        answer = re.findall(\"(\\\\-?[0-9\\\\.\\\\,]+)\", solution_str)\n        final_answer = None\n        if len(answer) == 0:\n            # no reward if there is no answer\n            pass\n        else:\n            invalid_str = [\"\", \".\"]\n            # find the last number that is not '.'\n            for final_answer in reversed(answer):\n                if final_answer not in invalid_str:\n                    break\n    return final_answer\n\ndef compute_score(data_source, solution_str, ground_truth, extra_info):\n    \"\"\"The scoring function for GSM8k.\n\n    Reference: Trung, Luong, et al. \"Reft: Reasoning with reinforced fine-tuning.\" Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 2024.\n\n    Args:\n        data_source: the dataset identifier (unused here)\n        solution_str: the solution text\n        ground_truth: the ground truth answer\n        extra_info: additional sample information (unused here)\n\n    The extraction method (\"strict\"), the format_score (0.0) and the correct-answer score (1.0) are fixed inside this function.\n    \"\"\"\n    method=\"strict\"\n    format_score=0.0\n    score=1.0\n    answer = extract_solution(solution_str=solution_str, method=method)\n    if answer is None:\n        return 0\n    else:\n        if answer == ground_truth:\n            return score\n        else:\n            return format_score\n"
  },
  {
    "path": "siirl/utils/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "siirl/utils/checkpoint/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "siirl/utils/checkpoint/checkpoint_manager.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport os\nimport random\nimport shutil\nimport tempfile\nfrom pathlib import Path\nfrom typing import Optional, Union\n\nimport numpy as np\nimport torch\nimport torch.distributed\nfrom filelock import FileLock\nfrom loguru import logger\nfrom omegaconf import DictConfig\nfrom transformers import PreTrainedTokenizer, ProcessorMixin\n\nfrom siirl.utils.extras.device import is_cuda_available, is_npu_available\nfrom siirl.params.model_args import CheckpointArguments\n\n\nclass BaseCheckpointManager:\n    \"\"\"\n    A checkpoint manager that saves and loads\n    - model\n    - optimizer\n    - lr_scheduler\n    - extra_states\n    in a SPMD way.\n\n    We save\n    - sharded model states and optimizer states\n    - full lr_scheduler states\n    - huggingface tokenizer and config for ckpt merge\n    \"\"\"\n\n    def __init__(\n        self,\n        model,\n        optimizer: torch.optim.Optimizer,\n        lr_scheduler: torch.optim.lr_scheduler.LRScheduler = None,\n        processing_class: Union[PreTrainedTokenizer, ProcessorMixin] = None,\n        checkpoint_contents: Optional[list] = None,\n        checkpoint_config: DictConfig | CheckpointArguments = None,\n    ):\n        self.checkpoint_config = checkpoint_config\n        if checkpoint_config:\n            checkpoint_load_contents = checkpoint_config.load_contents\n            checkpoint_save_contents = checkpoint_config.save_contents\n        else:\n            checkpoint_load_contents = None\n            checkpoint_save_contents = None\n\n        if checkpoint_load_contents is None:\n            checkpoint_load_contents = [\"model\", \"optimizer\", \"extra\"]\n        if checkpoint_save_contents is None:\n            checkpoint_save_contents = [\"model\", \"optimizer\", \"extra\"]\n\n        if checkpoint_contents is None:\n            checkpoint_contents = [\"model\", \"optimizer\", \"extra\"]\n\n        self.previous_global_step = None\n        self.previous_saved_paths = []\n\n        self.model = model\n        self.optimizer = optimizer\n        self.lr_scheduler = lr_scheduler\n        self.processing_class = processing_class\n        self.checkpoint_contents = checkpoint_contents\n        self.checkpoint_load_contents = checkpoint_load_contents\n        self.checkpoint_save_contents = checkpoint_save_contents\n\n        self.rank = torch.distributed.get_rank()\n        self.world_size = torch.distributed.get_world_size()\n\n    @property\n    def should_save_model(self) -> bool:\n        \"\"\"\n        Returns True if 'model' is in checkpoint_save_contents, indicating the model state should be saved.\n        \"\"\"\n        return \"model\" in self.checkpoint_save_contents\n\n    @property\n    def should_save_optimizer(self) -> bool:\n        \"\"\"\n        Returns True if 'optimizer' is in checkpoint_save_contents, indicating the optimizer state should be saved.\n        \"\"\"\n        return 
\"optimizer\" in self.checkpoint_save_contents\n\n    @property\n    def should_save_extra(self) -> bool:\n        \"\"\"\n        Returns True if 'extra' is in checkpoint_save_contents, indicating the extra state should be saved.\n        \"\"\"\n        return \"extra\" in self.checkpoint_save_contents\n\n    @property\n    def should_save_hf_model(self) -> bool:\n        \"\"\"\n        Returns True if 'hf_model' is in checkpoint_save_contents, indicating the model should be converted to hf\n        model and saved.\n        \"\"\"\n        return \"hf_model\" in self.checkpoint_save_contents\n\n    @property\n    def should_load_model(self) -> bool:\n        \"\"\"\n        Returns True if 'model' is in checkpoint_load_contents, indicating the model state should be loaded.\n        \"\"\"\n        return \"model\" in self.checkpoint_load_contents\n\n    @property\n    def should_load_optimizer(self) -> bool:\n        \"\"\"\n        Returns True if 'optimizer' is in checkpoint_load_contents, indicating the optimizer state should be loaded.\n        \"\"\"\n        return \"optimizer\" in self.checkpoint_load_contents\n\n    @property\n    def should_load_extra(self) -> bool:\n        \"\"\"\n        Returns True if 'extra' is in checkpoint_load_contents, indicating the extra state should be loaded.\n        \"\"\"\n        return \"extra\" in self.checkpoint_load_contents\n\n    def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load: bool = False):\n        raise NotImplementedError\n\n    def save_checkpoint(\n        self, local_path: str, hdfs_path: str = None, global_step: int = 0, max_ckpt_to_keep: int = None\n    ):\n        raise NotImplementedError\n\n    @staticmethod\n    def checkpath(local_path: str, hdfs_path: str):\n        assert local_path is not None or hdfs_path is not None, \"local_path and hdfs_path cannot be both None\"\n        return local_path is not None, local_path if local_path is not None else hdfs_path\n\n    def remove_previous_save_local_path(self, path):\n        if isinstance(path, str):\n            path = [path]\n        for p in path:\n            abs_path = os.path.abspath(p)\n            if not os.path.exists(abs_path):\n                continue\n            global_step_path = Path(p).parent\n            delete_path = abs_path\n            if \"global_step_\" in str(global_step_path) and os.path.exists(global_step_path):\n                delete_path = global_step_path\n            logger.info(f\"Checkpoint manager remove previous save local path: {delete_path}\")\n            shutil.rmtree(delete_path, ignore_errors=True)\n\n    @staticmethod\n    def local_mkdir(path):\n        if not os.path.isabs(path):\n            working_dir = os.getcwd()\n            path = os.path.join(working_dir, path)\n\n        # Using hash value of path as lock file name to avoid long file name\n        lock_filename = f\"ckpt_{hash(path) & 0xFFFFFFFF:08x}.lock\"\n        lock_path = os.path.join(tempfile.gettempdir(), lock_filename)\n\n        try:\n            with FileLock(lock_path, timeout=60):  # Add timeout\n                # make a new dir\n                os.makedirs(path, exist_ok=True)\n        except Exception as e:\n            logger.warning(f\"Warning: Failed to acquire lock for {path}: {e}\")\n            # Even if the lock is not acquired, try to create the directory\n            os.makedirs(path, exist_ok=True)\n\n        return path\n\n    @staticmethod\n    def get_rng_state():\n        rng_state = {\n            
\"cpu\": torch.get_rng_state(),\n            \"numpy\": np.random.get_state(),\n            \"random\": random.getstate(),\n        }\n\n        if is_cuda_available:\n            rng_state[\"cuda\"] = torch.cuda.get_rng_state()\n        elif is_npu_available:\n            rng_state[\"npu\"] = torch.npu.get_rng_state()\n\n        return rng_state\n\n    @staticmethod\n    def load_rng_state(rng_state):\n        torch.set_rng_state(rng_state[\"cpu\"])\n        np.random.set_state(rng_state[\"numpy\"])\n        random.setstate(rng_state[\"random\"])\n\n        if is_cuda_available:\n            torch.cuda.set_rng_state(rng_state[\"cuda\"])\n        elif is_npu_available:\n            torch.npu.set_rng_state(rng_state[\"npu\"])\n\n\ndef find_latest_ckpt_path(path, directory_format=\"global_step_{}\"):\n    \"\"\"\n    Return the most recent checkpoint directory based on a tracker file.\n\n    Args:\n        path (str): Base directory containing the checkpoint tracker.\n        directory_format (str): Template for checkpoint subfolders with one\n            placeholder for the iteration number (default \"global_step_{}\").\n\n    Returns:\n        str or None: Full path to the latest checkpoint directory, or\n        None if the tracker or checkpoint folder is missing.\n    \"\"\"\n    if path is None:\n        return None\n\n    tracker_file = get_checkpoint_tracker_filename(path)\n    if not os.path.exists(tracker_file):\n        logger.info(f\"Checkpoint tracker file does not exist: {tracker_file}\")\n        return None\n\n    with open(tracker_file, \"rb\") as f:\n        iteration = int(f.read().decode())\n    ckpt_path = os.path.join(path, directory_format.format(iteration))\n    if not os.path.exists(ckpt_path):\n        logger.info(f\"Checkpoint does not exist: {ckpt_path}\")\n        return None\n\n    logger.info(f\"Found checkpoint: {ckpt_path}\")\n    return ckpt_path\n\n\ndef get_checkpoint_tracker_filename(root_path: str):\n    \"\"\"\n    Tracker file rescords the latest checkpoint during training to restart from.\n    \"\"\"\n    return os.path.join(root_path, \"latest_checkpointed_iteration.txt\")\n"
  },
  {
    "path": "siirl/utils/checkpoint/fsdp_checkpoint_manager.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport warnings\nfrom typing import Optional, Union\n\nimport torch\nimport torch.distributed\nfrom accelerate import init_empty_weights\nfrom torch.distributed.fsdp import FullStateDictConfig, ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\nfrom transformers import GenerationConfig, PreTrainedTokenizer, ProcessorMixin\n\nfrom siirl.utils.extras.device import is_cuda_available\nfrom siirl.utils.extras.fs import copy_to_local, is_non_local\nfrom siirl.utils.model_utils.fsdp_utils import fsdp_version, get_fsdp_state_ctx\n\nfrom .checkpoint_manager import BaseCheckpointManager\n\n\nclass FSDPCheckpointManager(BaseCheckpointManager):\n    \"\"\"\n    Manage FSDP checkpointing in SPMD training.\n\n    - Saves/loads per-rank sharded model & optimizer states\n    - Persists full lr_scheduler and RNG state\n    - Stores HF tokenizer/processor and model/config for unified restore\n\n    Args:\n        model (FSDP): Wrapped model instance.\n        optimizer (Optimizer): Training optimizer.\n        lr_scheduler (LRScheduler): Learning-rate scheduler.\n        processing_class (PreTrainedTokenizer or ProcessorMixin, optional):\n            Pre-/post-processing artifact handler.\n        checkpoint_contents (list[str], optional):\n            Components to include; must contain 'model', 'optimizer', 'extra'.\n    \"\"\"\n\n    def __init__(\n        self,\n        model: FSDP,\n        optimizer: torch.optim.Optimizer,\n        lr_scheduler: torch.optim.lr_scheduler.LRScheduler,\n        processing_class: Union[PreTrainedTokenizer, ProcessorMixin] = None,\n        checkpoint_contents: Optional[list] = None,\n        **kwargs,\n    ):\n        if checkpoint_contents is None:\n            checkpoint_contents = [\"model\", \"optimizer\", \"extra\"]\n        if processing_class is None:\n            assert \"tokenizer\" in kwargs, \"tokenizer or processor must be provided\"\n            warnings.warn(\"`tokenizer` is deprecated. 
use `processing_class` instead.\", DeprecationWarning, stacklevel=2)\n            processing_class = kwargs.pop(\"tokenizer\")\n        assert \"model\" in checkpoint_contents and \"optimizer\" in checkpoint_contents and \"extra\" in checkpoint_contents, f\"FSDPCheckpointManager must include ['model', 'optimizer', 'extra'], got {checkpoint_contents}\"\n\n        super().__init__(\n            model,\n            optimizer,\n            lr_scheduler=lr_scheduler,\n            processing_class=processing_class,\n            checkpoint_contents=checkpoint_contents,\n        )\n\n    def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load=False):\n        \"\"\"\n        Load an FSDP checkpoint for this rank.\n\n        Downloads and loads:\n          - model and optimizer shards\n          - extra state dict (scheduler + RNG)\n\n        Args:\n            local_path: Directory with per-rank checkpoint files.\n            hdfs_path: Unused (for API compatibility).\n            del_local_after_load: Remove local files after loading.\n        \"\"\"\n        if local_path is None:\n            return\n\n        # every rank download its own checkpoint\n        remote_model_path = os.path.join(local_path, f\"model_world_size_{self.world_size}_rank_{self.rank}.pt\")\n        remote_optim_path = os.path.join(local_path, f\"optim_world_size_{self.world_size}_rank_{self.rank}.pt\")\n        remote_extra_state_path = os.path.join(local_path, f\"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt\")\n        print(f\"[rank-{self.rank}]: Loading from {remote_model_path} and {remote_optim_path} and {remote_extra_state_path}\")\n        local_model_path = copy_to_local(remote_model_path)\n        local_optim_path = copy_to_local(remote_optim_path)\n        local_extra_state_path = copy_to_local(remote_extra_state_path)\n\n        model_state_dict = torch.load(local_model_path, weights_only=False)\n        optimizer_state_dict = torch.load(local_optim_path, weights_only=False)\n        extra_state_dict = torch.load(local_extra_state_path, weights_only=False)\n\n        if del_local_after_load:\n            try:\n                os.remove(local_model_path) if is_non_local(local_model_path) else None\n                os.remove(local_optim_path) if is_non_local(local_optim_path) else None\n                os.remove(local_extra_state_path) if is_non_local(local_extra_state_path) else None\n            except Exception as e:\n                print(f\"[rank-{self.rank}]: remove local resume ckpt file after loading failed, exception {e} will be ignored\")\n\n        lr_scheduler_state_dict = extra_state_dict[\"lr_scheduler\"]\n\n        state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True if is_cuda_available else False)\n        optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True if is_cuda_available else False)\n        with get_fsdp_state_ctx(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg):\n            self.model.load_state_dict(model_state_dict)\n            if self.optimizer is not None:\n                self.optimizer.load_state_dict(optimizer_state_dict)\n        # recover random state\n        if \"rng\" in extra_state_dict:\n            # 'rng' may not exist for backward compatibility\n            self.load_rng_state(extra_state_dict[\"rng\"])\n\n        if self.lr_scheduler is not None:\n            self.lr_scheduler.load_state_dict(lr_scheduler_state_dict)\n\n    def save_checkpoint(self, local_path: str, hdfs_path: str = 
None, global_step: int = 0, max_ckpt_to_keep=None):\n        \"\"\"\n        Save an FSDP checkpoint for this rank.\n\n        Writes:\n          - model & optimizer shard files\n          - extra state dict (scheduler + RNG)\n          - HF tokenizer/processor and model/config on rank 0\n          - optional full HF model under 'huggingface/' if requested\n\n        Rotates old checkpoints, keeping at most `max_ckpt_to_keep`.\n\n        Args:\n            local_path: Target directory for checkpoint files.\n            hdfs_path: Unused (for API compatibility).\n            global_step: Current training step (used for bookkeeping).\n            max_ckpt_to_keep: Number of recent checkpoints to retain.\n        \"\"\"\n        if local_path is None:\n            return\n\n        # record the previous global step\n        self.previous_global_step = global_step\n\n        # remove previous local_path\n        if max_ckpt_to_keep and isinstance(max_ckpt_to_keep, int) and max_ckpt_to_keep > 0 and len(self.previous_saved_paths) >= max_ckpt_to_keep:\n            keep_start = len(self.previous_saved_paths) - max_ckpt_to_keep + 1\n            self.remove_previous_save_local_path(self.previous_saved_paths[:keep_start])\n            self.previous_saved_paths = self.previous_saved_paths[keep_start:]\n\n        local_path = self.local_mkdir(local_path)\n        torch.distributed.barrier()\n\n        # every rank will save its own model and optim shard\n        state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True if is_cuda_available else False)\n        optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True if is_cuda_available else False)\n        with warnings.catch_warnings():\n            warnings.simplefilter(\"ignore\")\n            with get_fsdp_state_ctx(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg):\n                model_state_dict = self.model.state_dict()\n                optimizer_state_dict = self.optimizer.state_dict() if self.optimizer is not None else None\n                lr_scheduler_state_dict = self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None\n\n                extra_state_dict = {\n                    \"lr_scheduler\": lr_scheduler_state_dict,\n                    \"rng\": self.get_rng_state(),\n                }\n                model_path = os.path.join(local_path, f\"model_world_size_{self.world_size}_rank_{self.rank}.pt\")\n                optim_path = os.path.join(local_path, f\"optim_world_size_{self.world_size}_rank_{self.rank}.pt\")\n                extra_path = os.path.join(local_path, f\"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt\")\n\n                print(f\"[rank-{self.rank}]: Saving model to {os.path.abspath(model_path)}\")\n                print(f\"[rank-{self.rank}]: Saving optim to {os.path.abspath(optim_path)}\")\n                print(f\"[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}\")\n                torch.save(model_state_dict, model_path)\n                torch.save(optimizer_state_dict, optim_path)  # TODO: address optimizer is None\n                torch.save(extra_state_dict, extra_path)\n\n        if self.rank == 0:\n            if fsdp_version(self.model) == 1:\n                unwrap_model = self.model._fsdp_wrapped_module\n            else:\n                unwrap_model = self.model\n\n            model_config = unwrap_model.config\n            if unwrap_model.can_generate() and hasattr(model_config, \"name_or_path\") and 
model_config.name_or_path:\n                # Some models' name_or_path is empty if not initialized from pretrained;\n                # in that case, we don't save the generation config.\n                generation_config = GenerationConfig.from_pretrained(model_config.name_or_path)\n                generation_config.save_pretrained(local_path)\n            else:\n                generation_config = None\n\n            model_config.save_pretrained(local_path)\n            self.processing_class.save_pretrained(local_path)\n\n        # wait for everyone to dump to local\n        torch.distributed.barrier()\n\n        if \"hf_model\" in self.checkpoint_contents:\n            hf_local_path = os.path.join(local_path, \"huggingface\")\n            os.makedirs(hf_local_path, exist_ok=True)\n\n            # Only rank 0 will save the hf model, and\n            # offload to cpu to save LLMs which may be too large to fit in one GPU\n            state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)\n            with get_fsdp_state_ctx(self.model, StateDictType.FULL_STATE_DICT, state_dict_config, None):\n                state_dict = self.model.state_dict()\n\n            if self.rank == 0:\n                if \"ForTokenClassification\" in model_config.architectures[0]:\n                    from transformers import AutoModelForTokenClassification\n\n                    auto_model_cls = AutoModelForTokenClassification\n                elif \"ForCausalLM\" in model_config.architectures[0]:\n                    from transformers import AutoModelForCausalLM\n\n                    auto_model_cls = AutoModelForCausalLM\n                elif \"ForConditionalGeneration\" in model_config.architectures[0]:\n                    from transformers import AutoModelForVision2Seq\n\n                    auto_model_cls = AutoModelForVision2Seq\n                else:\n                    raise NotImplementedError(f\"Unknown architecture {model_config.architectures}\")\n\n                with init_empty_weights():\n                    save_model = auto_model_cls.from_config(model_config, torch_dtype=torch.bfloat16)\n                save_model.to_empty(device=\"cpu\")\n\n                if save_model.can_generate():\n                    if generation_config is not None:\n                        save_model.generation_config = generation_config\n                    else:\n                        print(f\"Warning: {self.__class__.__name__}.save_checkpoint: Generation config file not found; using a generation config created from the model config when saving hf_model.\")\n\n                save_model.save_pretrained(hf_local_path, state_dict=state_dict)\n                self.processing_class.save_pretrained(hf_local_path)\n                del state_dict\n                del save_model\n\n            # wait for rank0 to dump hf_model to local\n            torch.distributed.barrier()\n\n        self.previous_saved_paths.append(local_path)\n"
  },
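  {
    "path": "examples/sketches/fsdp_checkpoint_manager_usage.py",
    "content": "\"\"\"Hypothetical usage sketch for FSDPCheckpointManager (added for documentation only).\n\nThis file is not part of the original library. It is an illustrative, hedged example of\nhow the save/load API defined in siirl/utils/checkpoint/fsdp_checkpoint_manager.py is\ntypically driven. The variables `fsdp_model`, `optimizer`, `lr_scheduler`, and `tokenizer`,\nthe exact constructor keyword names, and the directory layout are assumptions; a real run\nalso requires torch.distributed to be initialized with an FSDP-wrapped model on every rank.\n\"\"\"\n\nfrom siirl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager\n\n\ndef checkpoint_step(fsdp_model, optimizer, lr_scheduler, tokenizer, global_step: int):\n    # checkpoint_contents mirrors the assertion in FSDPCheckpointManager.__init__:\n    # it must contain at least 'model', 'optimizer', and 'extra'; adding 'hf_model'\n    # additionally writes a full HuggingFace model under 'huggingface/'.\n    manager = FSDPCheckpointManager(\n        fsdp_model,\n        optimizer,\n        lr_scheduler=lr_scheduler,\n        processing_class=tokenizer,\n        checkpoint_contents=[\"model\", \"optimizer\", \"extra\"],\n    )\n\n    # Every rank writes its own model/optimizer shard plus an extra state dict\n    # (lr scheduler + RNG); only the 3 most recent checkpoints are kept.\n    manager.save_checkpoint(\n        local_path=f\"checkpoints/global_step_{global_step}\",\n        global_step=global_step,\n        max_ckpt_to_keep=3,\n    )\n\n    # When resuming, every rank loads back its own shards from the same layout.\n    manager.load_checkpoint(local_path=f\"checkpoints/global_step_{global_step}\")\n"
  },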
  {
    "path": "siirl/utils/checkpoint/megatron_checkpoint_manager.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport logging\nimport os\nimport random\nfrom collections.abc import Callable\nfrom dataclasses import asdict\n\nimport numpy as np\nimport torch\nimport torch.distributed\nfrom megatron.core import mpu, tensor_parallel\nfrom megatron.core.dist_checkpointing.mapping import ShardedObject\nfrom megatron.core.transformer.enums import AttnBackend\nfrom transformers import GenerationConfig\n\nfrom siirl.models.weight_loader_registry import get_weight_saver\nfrom siirl.utils.extras.device import get_device_name, get_torch_device\nfrom siirl.utils.extras.fs import is_non_local, local_mkdir_safe\nfrom siirl.utils.logger.aggregate_logger import log_with_rank\nfrom siirl.utils.megatron.megatron_utils import (\n    get_dist_checkpoint_path,\n    get_hf_model_checkpoint_path,\n    get_transformer_config_checkpoint_path,\n)\nfrom siirl.utils.megatron.dist_checkpointing import load_dist_checkpointing, save_dist_checkpointing\nfrom .checkpoint_manager import BaseCheckpointManager\n\nlogger = logging.getLogger(__name__)\n\nclass MegatronCheckpointManager(BaseCheckpointManager):\n    \"\"\"\n    Checkpoint manager for Megatron-LM distributed training.\n\n    This class manages the saving and loading of model checkpoints in a Megatron-LM\n    distributed training environment. 
It handles various aspects of checkpointing\n    including model states, optimizer states, learning rate schedulers, and random\n    number generator states, ensuring compatibility with HuggingFace formats.\n\n    Key features:\n    - Distributed checkpoint saving and loading using Megatron's dist_checkpointing\n    - Support for tensor parallel, pipeline parallel, and data parallel configurations\n    - Automatic handling of model state dictionaries across multiple pipeline stages\n    - Integration with HuggingFace model configurations and tokenizers\n    - Random number generator state management for reproducibility\n    - Support for both synchronous and asynchronous checkpoint operations\n\n    The manager automatically handles:\n    - Directory structure creation based on global steps and process ranks\n    - Model configuration and tokenizer saving in HuggingFace format\n    - Optimizer and scheduler state persistence\n    - CUDA RNG state management for deterministic training\n    - Checkpoint cleanup and retention policies\n\n    Args:\n        model: The Megatron model instance to checkpoint\n        optimizer: The optimizer instance (optional)\n        lr_scheduler: The learning rate scheduler instance (optional)\n\n    Attributes:\n        model: Reference to the Megatron model being checkpointed\n        optimizer: Reference to the optimizer (if provided)\n        lr_scheduler: Reference to the learning rate scheduler (if provided)\n        rank: Current process rank in the distributed setup\n\n    Example:\n        ```python\n        checkpoint_manager = MegatronCheckpointManager(\n            model=megatron_model,\n            optimizer=optimizer,\n            lr_scheduler=scheduler\n        )\n\n        checkpoint_manager.save_checkpoint(\n            local_path=\"checkpoints/step_1000\",\n            global_step=1000\n        )\n\n        checkpoint_manager.load_checkpoint(\n            local_path=\"checkpoints/step_1000\"\n        )\n        ```\n    \"\"\"\n\n    def __init__(\n        self,\n        config,\n        checkpoint_config,\n        model_config,\n        transformer_config,\n        role,\n        model: torch.nn.ModuleList,\n        arch: str,\n        hf_config,\n        param_dtype: torch.dtype,\n        share_embeddings_and_output_weights: bool,\n        processing_class,\n        optimizer,\n        optimizer_scheduler,\n        use_distributed_optimizer: bool,\n        use_checkpoint_opt_param_scheduler: bool = False,\n        use_dist_checkpointing: bool = True,\n        bridge=None,\n        **kwargs,\n    ):\n        super().__init__(\n            model,\n            optimizer=optimizer,\n            lr_scheduler=optimizer_scheduler,\n            processing_class=processing_class,\n            checkpoint_config=checkpoint_config,\n        )\n        self.arch = arch\n        self.config = config\n        self.transformer_config = transformer_config\n        self.role = role\n        self.is_value_model = False\n        if self.role in [\"reward\", \"critic\"]:\n            self.is_value_model = True\n        self.model_config = model_config\n        self.hf_config = hf_config\n        self.param_dtype = param_dtype\n        self.share_embeddings_and_output_weights = share_embeddings_and_output_weights\n        self.model_path = self.config.model.path\n        self.use_distributed_optimizer = use_distributed_optimizer\n        self.use_checkpoint_opt_param_scheduler = use_checkpoint_opt_param_scheduler\n        self.bridge = bridge\n        self.rank 
= torch.distributed.get_rank()\n        self.use_dist_checkpointing = use_dist_checkpointing or not self.bridge or self.is_value_model\n        self.use_hf_checkpoint = not self.use_dist_checkpointing\n\n        self.weight_saver = get_weight_saver(self.arch)\n\n    def get_rng_state(self, use_dist_ckpt: bool = True, data_parallel_random_init: bool = False):\n        \"\"\"collect rng state across data parallel ranks\"\"\"\n        rng_state = {\n            \"random_rng_state\": random.getstate(),\n            \"np_rng_state\": np.random.get_state(),\n            \"torch_rng_state\": torch.get_rng_state(),\n            \"rng_tracker_states\": tensor_parallel.get_cuda_rng_tracker().get_states(),\n        }\n\n        if get_device_name() != \"cpu\":\n            rng_state[f\"{get_device_name()}_rng_state\"] = get_torch_device().get_rng_state()\n\n        rng_state_list = None\n        if torch.distributed.is_initialized() and mpu.get_data_parallel_world_size() > 1 and data_parallel_random_init:\n            rng_state_list = [None for i in range(mpu.get_data_parallel_world_size())]\n            torch.distributed.all_gather_object(rng_state_list, rng_state, group=mpu.get_data_parallel_group())\n        else:\n            rng_state_list = [rng_state]\n\n        if use_dist_ckpt:\n            pp_rank = mpu.get_pipeline_model_parallel_rank()\n            pp_size = mpu.get_pipeline_model_parallel_world_size()\n            tp_rank = mpu.get_tensor_model_parallel_rank()\n            tp_size = mpu.get_tensor_model_parallel_world_size()\n            rng_state_list = ShardedObject(\n                \"rng_state\",\n                rng_state_list,\n                (pp_size, tp_size),\n                (pp_rank, tp_rank),\n                replica_id=mpu.get_data_parallel_rank(with_context_parallel=True),\n            )\n\n        return rng_state_list\n\n    def get_checkpoint_name(\n        self,\n        checkpoints_path,\n        pipeline_parallel=None,\n        tensor_rank=None,\n        pipeline_rank=None,\n        cp_rank=None,\n        expert_parallel=None,\n        expert_rank=None,\n        return_base_dir=True,\n        basename=\"model.pt\",\n    ):\n        \"\"\"Determine the directory name for this rank's checkpoint.\"\"\"\n        # Use both the tensor and pipeline MP rank.\n        if pipeline_parallel is None:\n            pipeline_parallel = mpu.get_pipeline_model_parallel_world_size() > 1\n        if tensor_rank is None:\n            tensor_rank = mpu.get_tensor_model_parallel_rank()\n        if pipeline_rank is None:\n            pipeline_rank = mpu.get_pipeline_model_parallel_rank()\n        if cp_rank is None:\n            cp_rank = mpu.get_context_parallel_rank()\n        if expert_parallel is None:\n            expert_parallel = mpu.get_expert_model_parallel_world_size() > 1\n        if expert_rank is None:\n            expert_rank = mpu.get_expert_model_parallel_rank()\n\n        # Use both the tensor and pipeline MP rank. 
If using the distributed\n        # optimizer, then the optimizer's path must additionally include the\n        # data parallel rank.\n\n        # due to the fact that models are identical across cp ranks, cp rank is not used in the checkpoint path\n        if not pipeline_parallel:\n            common_path = os.path.join(checkpoints_path, f\"mp_rank_{tensor_rank:02d}\")\n        else:\n            common_path = os.path.join(checkpoints_path, f\"mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}\")\n\n        if expert_parallel:\n            common_path = common_path + f\"_{expert_rank:03d}\"\n\n        os.makedirs(common_path, exist_ok=True)\n\n        if return_base_dir:\n            return common_path\n        return os.path.join(common_path, basename)\n\n    def generate_state_dict(\n        self, generate_model: bool = True, generate_optimizer: bool = True, generate_extra: bool = True\n    ):\n        # For save dist checkpointing\n        state_dict = {}\n\n        # Should always generate model state dict\n        # All ranks Save Model to reduce memory pressure\n        # Get sharded state dict, notice that state_dict will collect among dp groups, causing memory pressure\n        for vpp_rank, model in enumerate(self.model):\n            if len(self.model) > 1:\n                mpu.set_virtual_pipeline_model_parallel_rank(vpp_rank)\n                key = f\"model{vpp_rank}\" if len(self.model) > 1 else \"model\"\n            else:\n                key = \"model\"\n            if hasattr(model, \"module\"):\n                model = model.module\n            state_dict[key] = model.sharded_state_dict()\n\n        # Optimizer State Dict\n        if generate_optimizer:\n            torch.distributed.barrier()\n            optimizer_sharded_states = self.optimizer.sharded_state_dict(state_dict)\n            state_dict[\"optimizer\"] = optimizer_sharded_states\n\n            if self.lr_scheduler is not None:\n                lr_state_dict = self.lr_scheduler.state_dict()\n                state_dict[\"lr_scheduler\"] = lr_state_dict\n\n        if not generate_model:\n            state_dict.pop(\"model\", None)\n\n        # RNG States State Dict\n        if generate_extra:\n            torch.distributed.barrier()\n            rng_state = self.get_rng_state()\n            state_dict[\"rng_state\"] = rng_state\n\n        return state_dict\n\n    def load_rng_states(self, rng_states, data_parallel_random_init=False, use_dist_ckpt=True):\n        # access rng_state for data parallel rank\n        if data_parallel_random_init:\n            rng_states = rng_states[mpu.get_data_parallel_rank()]\n        else:\n            rng_states = rng_states[0]\n        random.setstate(rng_states[\"random_rng_state\"])\n        np.random.set_state(rng_states[\"np_rng_state\"])\n        torch.set_rng_state(rng_states[\"torch_rng_state\"])\n\n        if get_device_name() != \"cpu\":\n            get_torch_device().set_rng_state(rng_states[f\"{get_device_name()}_rng_state\"])\n\n        # Check for empty states array\n        if not rng_states[\"rng_tracker_states\"]:\n            raise KeyError\n        tensor_parallel.get_cuda_rng_tracker().set_states(rng_states[\"rng_tracker_states\"])\n\n    def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load=False):\n        if local_path is not None:\n            assert os.path.exists(local_path), f\"Checkpoint path {local_path} does not exist.\"\n\n        dist_checkpoint_path = get_dist_checkpoint_path(local_path)\n\n        # Get State Dict 
for loading\n        sharded_state_dict = self.generate_state_dict(\n            self.should_load_model and self.use_dist_checkpointing, self.should_load_optimizer, self.should_load_extra\n        )\n        log_with_rank(f\"Generated state dict for loading: {sharded_state_dict.keys()}\", rank=self.rank, logger=logger)\n\n        # Load Dist Checkpointing\n        state_dict = load_dist_checkpointing(\n            sharded_state_dict=sharded_state_dict,\n            ckpt_dir=dist_checkpoint_path,\n        )\n\n        if self.should_load_model and self.use_dist_checkpointing:\n            assert \"model\" in state_dict or any(\n                f\"model{vpp_rank}\" in state_dict for vpp_rank in range(len(self.model))\n            ), f\"Model state dict not found in {state_dict.keys()}. Please check the checkpoint file {local_path}.\"\n            for vpp_rank, model in enumerate(self.model):\n                if len(self.model) == 1:\n                    model_state_dict = state_dict[\"model\"]\n                else:\n                    assert f\"model{vpp_rank}\" in state_dict, f\"model{vpp_rank} not found in state_dict\"\n                    model_state_dict = state_dict[f\"model{vpp_rank}\"]\n                mpu.set_virtual_pipeline_model_parallel_rank(vpp_rank)\n                self.model[vpp_rank].load_state_dict(model_state_dict)\n            log_with_rank(f\"Loaded sharded model checkpoint from {local_path}\", rank=self.rank, logger=logger)\n        elif self.should_load_model and self.use_hf_checkpoint:\n            hf_model_path = get_hf_model_checkpoint_path(local_path)\n            self.bridge.load_weights(self.model, hf_model_path)\n            log_with_rank(f\"Loaded HF model checkpoint from {hf_model_path} with bridge\", rank=self.rank, logger=logger)\n\n        if self.should_load_optimizer:\n            assert \"optimizer\" in state_dict, (\n                f\"Optimizer state dict not found in {state_dict.keys()}. Please check the checkpoint file {local_path}.\"\n            )\n            optimizer_state_dict = state_dict[\"optimizer\"]\n            self.optimizer.load_state_dict(optimizer_state_dict)\n            log_with_rank(f\"Loaded optimizer checkpoint from {local_path}\", rank=self.rank, logger=logger)\n            if self.use_checkpoint_opt_param_scheduler:\n                assert \"lr_scheduler\" in state_dict, (\n                    f\"LR scheduler state dict not found in {state_dict.keys()}. Please check the checkpoint file \"\n                    f\"{local_path}.\"\n                )\n                lr_scheduler_state_dict = state_dict[\"lr_scheduler\"]\n                if self.lr_scheduler is not None:\n                    self.lr_scheduler.load_state_dict(lr_scheduler_state_dict)\n                    log_with_rank(f\"Loaded LR scheduler checkpoint from {local_path}\", rank=self.rank, logger=logger)\n\n        if self.should_load_extra:\n            assert \"rng_state\" in state_dict, (\n                f\"RNG state dict not found in {state_dict.keys()}. 
Please check the checkpoint file {local_path}.\"\n            )\n            rng_state = state_dict[\"rng_state\"]\n            self.load_rng_states(rng_state)\n            log_with_rank(f\"Loaded RNG states from {local_path}\", rank=self.rank, logger=logger)\n\n        if del_local_after_load:\n            try:\n                os.remove(local_path) if is_non_local(local_path) else None\n            except Exception as e:\n                log_with_rank(\n                    f\"remove local resume ckpt file after loading failed, exception {e} will be ignored\",\n                    rank=self.rank,\n                    logger=logger,\n                )\n\n    def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: int = 0, max_ckpt_to_keep=None):\n        # record the previous global step\n        self.previous_global_step = global_step\n\n        # remove previous local_path\n        if (\n            max_ckpt_to_keep\n            and isinstance(max_ckpt_to_keep, int)\n            and max_ckpt_to_keep > 0\n            and len(self.previous_saved_paths) >= max_ckpt_to_keep\n        ):\n            keep_start = len(self.previous_saved_paths) - max_ckpt_to_keep + 1\n            self.remove_previous_save_local_path(self.previous_saved_paths[:keep_start])\n            self.previous_saved_paths = self.previous_saved_paths[keep_start:]\n\n        local_path = local_mkdir_safe(local_path)\n        dist_checkpoint_path = get_dist_checkpoint_path(local_path)\n\n        # Note that model weights, optimizer states, and extra states are generated\n        # together in a state dict, we save them in one time\n        if self.use_dist_checkpointing:\n            # Generate state dict for saving\n            state_dict = self.generate_state_dict(\n                self.should_save_model, self.should_save_optimizer, self.should_save_extra\n            )\n            log_with_rank(f\"Generated state dict for saving: {state_dict.keys()}\", rank=self.rank, logger=logger)\n            for vpp_rank, model in enumerate(self.model):\n                if len(self.model) > 1:\n                    model_i_keys = state_dict[f\"model{vpp_rank}\"].keys()\n                    log_with_rank(f\"Generated state dict for saving: {model_i_keys}\", rank=self.rank, logger=logger)\n                else:\n                    log_with_rank(\n                        f\"Generated state dict for saving: {state_dict['model'].keys()}\", rank=self.rank, logger=logger\n                    )\n            # Start Async save if enabled\n            async_save_request = save_dist_checkpointing(\n                sharded_state_dict=state_dict,\n                ckpt_path=dist_checkpoint_path,\n                async_save=self.checkpoint_config.async_save,\n            )\n\n            # Synchronize all async save requests\n            if not self.checkpoint_config.async_save:\n                assert async_save_request is None, \"Async save request should be None when not using async save.\"\n                torch.distributed.barrier()\n        else:\n            assert self.use_hf_checkpoint, \"When not using distributed checkpointing, use_hf_checkpoint should be True.\"\n            # Generate optimizer and exra state dicts\n            state_dict = self.generate_state_dict(\n                generate_model=False,\n                generate_optimizer=self.should_save_optimizer,\n                generate_extra=self.should_save_extra,\n            )\n            # Save optimizer and extra states to local path\n            # 
Start Async save if enabled\n            async_save_request = save_dist_checkpointing(\n                sharded_state_dict=state_dict,\n                ckpt_path=dist_checkpoint_path,\n                async_save=self.checkpoint_config.async_save,\n            )\n\n            # Synchronize all async save requests\n            if not self.checkpoint_config.async_save:\n                assert async_save_request is None, \"Async save request should be None when not using async save.\"\n                torch.distributed.barrier()\n\n        if self.should_save_model:\n            if self.use_hf_checkpoint:\n                # Use mbridge to save HF model checkpoint\n                log_with_rank(f\"Saving HF model checkpoint to {local_path} with bridge\", rank=self.rank, logger=logger)\n                hf_ckpt_path = get_hf_model_checkpoint_path(local_path)\n                self.bridge.save_weights(self.model, hf_ckpt_path)\n                log_with_rank(f\"Saved bridge checkpoint to {hf_ckpt_path}\", rank=self.rank, logger=logger)\n\n            # Only rank 0 saves the hf config and tokenizer to huggingface path\n            # No matter whether we save hf model or not\n            if self.rank == 0:\n                # Save tokenizer\n                hf_config_tokenizer_path = get_hf_model_checkpoint_path(local_path)\n                self.processing_class.save_pretrained(hf_config_tokenizer_path)\n                # Save huggingface config\n                self.hf_config.save_pretrained(hf_config_tokenizer_path)\n                if hasattr(self.hf_config, \"name_or_path\") and self.hf_config.name_or_path:\n                    try:\n                        generation_config = GenerationConfig.from_pretrained(self.hf_config.name_or_path)\n                        generation_config.save_pretrained(hf_config_tokenizer_path)\n                    except Exception:\n                        # if the generation config isn't available, we don't save it\n                        pass\n                log_with_rank(\n                    f\"Saved Huggingface config and tokenizer to {hf_config_tokenizer_path}\",\n                    rank=self.rank,\n                    logger=logger,\n                    log_only_rank_0=True,\n                )\n\n        if self.should_save_extra:\n            if self.rank == 0:\n                # Save transformer config\n                print(self.transformer_config)\n                transformer_config_dict = asdict(self.transformer_config)\n                to_convert_types = {torch.dtype: str, AttnBackend: str}\n                ignore_types = [Callable]\n                pop_keys = []\n                for key, value in transformer_config_dict.items():\n                    if type(value) in to_convert_types:\n                        transformer_config_dict[key] = to_convert_types[type(value)](value)\n                    if type(value) in ignore_types:\n                        pop_keys.append(key)\n                    if callable(value):\n                        pop_keys.append(key)\n                for key in pop_keys:\n                    transformer_config_dict.pop(key)\n                transformer_config_path = get_transformer_config_checkpoint_path(local_path)\n                with open(transformer_config_path, \"w\") as f:\n                    json.dump(transformer_config_dict, f, indent=2)\n\n        if self.should_save_hf_model and not self.use_hf_checkpoint:\n            # wait for everyone to dump to local\n            state_dict = self.weight_saver(\n                
self.model,\n                self.hf_config,\n                dtype=self.param_dtype,\n                is_value_model=self.is_value_model,\n                tie_word_embeddings=self.share_embeddings_and_output_weights,\n            )\n\n            torch.distributed.barrier()\n            if self.rank == 0:\n                hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path)\n                import warnings\n\n                from accelerate import init_empty_weights\n\n                with init_empty_weights(), warnings.catch_warnings():\n                    warnings.simplefilter(\"ignore\")\n                    if \"mistral7b-rm\" in self.config.model.path:\n                        from transformers import MistralForSequenceClassification\n\n                        model = MistralForSequenceClassification.from_pretrained(\n                            self.config.model.path\n                        )  # use score head instead of lm_head\n                        state_dict[\"score.weight\"] = state_dict[\"score.weight\"]\n                    else:\n                        from transformers import AutoModelForCausalLM\n\n                        model = AutoModelForCausalLM.from_pretrained(self.config.model.path, torch_dtype=\"auto\")\n                model.save_pretrained(hf_model_ckpt_path, state_dict=state_dict)\n                log_with_rank(\n                    f\"Saved Huggingface config and tokenizer to {hf_model_ckpt_path}\",\n                    rank=self.rank,\n                    logger=logger,\n                    log_only_rank_0=True,\n                )\n\n                if hdfs_path is not None:\n                    log_with_rank(\n                        f\"Uploading checkpoint to {hdfs_path}\", rank=self.rank, logger=logger, log_only_rank_0=True\n                    )\n                    from siirl.utils.extras import hdfs_io\n\n                    hdfs_io.makedirs(hdfs_path, exist_ok=True)\n                    hdfs_io.copy(src=hf_model_ckpt_path, dst=hdfs_path, dirs_exist_ok=True)\n                    log_with_rank(\n                        f\"HDFS checkpoint uploaded to {hdfs_path}\", rank=self.rank, logger=logger, log_only_rank_0=True\n                    )\n\n        def finalize_save_fn():\n            # Rank 0 uploads checkpoint to HDFS if hdfs_path is provided\n            log_with_rank(\n                f\"Dist checkpointing save completed for {dist_checkpoint_path}\", rank=self.rank, logger=logger\n            )\n            if self.rank == 0:\n                if hdfs_path is not None:\n                    log_with_rank(f\"Uploading checkpoint to {hdfs_path}\", rank=self.rank, logger=logger)\n                    from siirl.utils.extras import hdfs_io\n\n                    hdfs_io.makedirs(hdfs_path, exist_ok=True)\n                    hdfs_io.copy(src=dist_checkpoint_path, dst=hdfs_path, dirs_exist_ok=True)\n                    hdfs_io.copy(src=hf_config_tokenizer_path, dst=hdfs_path, dirs_exist_ok=True)\n\n        if self.checkpoint_config.async_save:\n            assert async_save_request is not None, \"Async save request should not be None when using async save.\"\n            async_save_request.add_finalize_fn(finalize_save_fn)\n        else:\n            finalize_save_fn()\n\n        self.previous_saved_paths.append(local_path)\n"
  },
  {
    "path": "siirl/utils/debug/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .performance import GPUMemoryLogger, log_gpu_memory_usage\nfrom siirl.utils.extras.import_utils import is_nvtx_available\nfrom siirl.utils.extras.device import is_npu_available\nfrom siirl.utils.debug.profile import DistProfiler\n\nif is_nvtx_available():\n    from .profile import DistProfiler, mark_annotate, mark_end_range, mark_start_range\nelif is_npu_available:\n    from .mstx_profile import NPUProfiler as DistProfiler\n    from .mstx_profile import mark_annotate, mark_end_range, mark_start_range\nelse:\n    from .profile import DistProfiler, mark_annotate, mark_end_range, mark_start_range\n\n__all__ = [\n    \"GPUMemoryLogger\",\n    \"log_gpu_memory_usage\",\n    \"DistProfiler\",\n    \"mark_annotate\",\n    \"mark_end_range\",\n    \"mark_start_range\",\n]\n"
  },
  {
    "path": "siirl/utils/debug/mstx_profile.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Inspired from https://gitee.com/ascend/MindSpeed-RL/blob/master/mindspeed_rl/utils/utils.py\nimport functools\nimport logging\nimport os\nfrom contextlib import contextmanager\nfrom typing import Any, Callable, Optional\n\nimport torch_npu\nfrom torch_npu.npu import mstx\nfrom siirl.params import ProfilerArguments\nfrom loguru import logger\n\nfrom .profile import DistProfiler\n\n\ndef mark_start_range(message: Optional[str] = None) -> None:\n    \"\"\"Start a mark range in the profiler.\n\n    Args:\n        message (str, optional):\n            The message to be displayed in the profiler. Defaults to None.\n    \"\"\"\n    return mstx.range_start(message=message)\n\n\ndef mark_end_range(range_id: str) -> None:\n    \"\"\"End a mark range in the profiler.\n\n    Args:\n        range_id (str):\n            The id of the mark range to end.\n    \"\"\"\n    return mstx.range_end(range_id)\n\n\ndef mark_annotate(message: Optional[str] = None) -> Callable:\n    \"\"\"Decorate a function to annotate a mark range along with the function life cycle.\n\n    Args:\n        message (str, optional):\n            The message to be displayed in the profiler. 
Defaults to None.\n    \"\"\"\n\n    def decorator(func):\n        profile_message = message or func.__name__\n        return mstx.mstx_range(profile_message)(func)\n\n    return decorator\n\n\n@contextmanager\ndef marked_timer(name: str, timing_raw: dict[str, float], *args: Any, **kwargs: Any) -> None:\n    \"\"\"Context manager for timing with MSTX markers.\n\n    This utility function measures the execution time of code within its context,\n    accumulates the timing information, and adds MSTX markers for profiling.\n\n    Args:\n        name (str): The name/identifier for this timing measurement.\n        timing_raw (Dict[str, float]): Dictionary to store timing information.\n\n    Yields:\n        None: This is a context manager that yields control back to the code block.\n    \"\"\"\n    if args:\n        logging.warning(f\"Args are not supported in mstx_profile, but received: {args}\")\n    if kwargs:\n        logging.warning(f\"Kwargs are not supported in mstx_profile, but received: {kwargs}\")\n    mark_range = mark_start_range(message=name)\n    from .performance import _timer\n\n    yield from _timer(name, timing_raw)\n    mark_end_range(mark_range)\n\n\ndef get_npu_profiler(config: ProfilerArguments, role: Optional[str] = None, profile_step: Optional[str] = None):\n    \"\"\"Generate and return an NPU profiler object.\n    \"\"\"\n    if config.level == \"level_none\":\n        profile_level = torch_npu.profiler.ProfilerLevel.Level_none\n    elif config.level == \"level0\":\n        profile_level = torch_npu.profiler.ProfilerLevel.Level0\n    elif config.level == \"level1\":\n        profile_level = torch_npu.profiler.ProfilerLevel.Level1\n    elif config.level == \"level2\":\n        profile_level = torch_npu.profiler.ProfilerLevel.Level2\n    else:\n        raise ValueError(f\"level only supports level0, 1, 2, and level_none, but gets {config.level}\")\n\n    profile_save_path = config.save_path\n    if profile_step:\n        profile_save_path = os.path.join(profile_save_path, profile_step)\n    if role:\n        profile_save_path = os.path.join(profile_save_path, role)\n\n    experimental_config = torch_npu.profiler._ExperimentalConfig(\n        aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,\n        profiler_level=profile_level,\n        export_type=torch_npu.profiler.ExportType.Text,\n        data_simplification=True,\n        msprof_tx=True,\n    )\n\n    activites = []\n    if config.with_npu:\n        activites.append(torch_npu.profiler.ProfilerActivity.NPU)\n    if config.with_cpu:\n        activites.append(torch_npu.profiler.ProfilerActivity.CPU)\n\n    prof = torch_npu.profiler.profile(\n        with_modules=config.with_module,\n        with_stack=config.with_stack,\n        record_shapes=config.record_shapes,\n        profile_memory=config.with_memory,\n        activities=activites,\n        on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(profile_save_path, analyse_flag=config.analysis),\n        experimental_config=experimental_config,\n    )\n    return prof\n\n\nclass NPUProfiler(DistProfiler):\n    \"\"\"\n    NPU profiler. 
Initialized in a worker to control the NPU profiler.\n    \"\"\"\n\n    _define_count = 0\n\n    def __init__(self, rank: int, config: ProfilerArguments, **kwargs):\n        self.switch: bool = False\n        self.config = config\n        self.match_rank: bool = True if config.all_ranks else rank in config.ranks\n        self.profile_npu = None\n\n    def start(self, **kwargs):\n        role, profile_step = kwargs.get(\"role\", None), kwargs.get(\"profile_step\", None)\n        profile_step = str(profile_step) if profile_step is not None else None\n        if self.match_rank and self.config.enable:\n            self.switch = True\n            if not self.config.discrete and NPUProfiler._define_count == 0:\n                self.profile_npu = get_npu_profiler(config=self.config, role=role, profile_step=profile_step)\n                self.profile_npu.start()\n                NPUProfiler._define_count += 1\n\n    def stop(self):\n        if self.match_rank and self.config.enable:\n            self.switch = False\n            if not self.config.discrete and NPUProfiler._define_count == 1:\n                self.profile_npu.step()\n                self.profile_npu.stop()\n                NPUProfiler._define_count -= 1\n\n    @staticmethod\n    def annotate(message: Optional[str] = None, role: Optional[str] = None, **kwargs) -> Callable:\n        \"\"\"Decorate a Worker member function to profile the current rank in the current training step.\n\n        Requires the target function to be a member function of a Worker,\n        which has a member field `profiler` with NPUProfiler type.\n\n        Args:\n            message (str, optional):\n                The message to be displayed in the profiler. Defaults to None.\n            role (str, optional):\n                The role of the current data collection. Defaults to None.\n        \"\"\"\n\n        def decorator(func):\n            @functools.wraps(func)\n            def wrapper(self, *args, **kwargs):\n                profile_name = message or func.__name__\n                match_role = True\n                discrete_mode = self._profiler.config.discrete\n                profile_enable = self._profiler.switch\n\n                if not profile_enable:\n                    return func(self, *args, **kwargs)\n\n                if profile_enable and role is not None:\n                    target_roles = self._profiler.config.roles\n                    match_role = True if not discrete_mode else role in target_roles\n\n                if profile_enable:\n                    if not discrete_mode:\n                        mark_range = mark_start_range(message=profile_name)\n                    else:\n                        if match_role:\n                            profile_npu = get_npu_profiler(config=self._profiler.config, role=role)\n                            profile_npu.start()\n                            mark_range = mark_start_range(message=profile_name)\n\n                result = func(self, *args, **kwargs)\n\n                if profile_enable:\n                    if not discrete_mode:\n                        mark_end_range(mark_range)\n                    else:\n                        if match_role:\n                            mark_end_range(mark_range)\n                            profile_npu.step()\n                            profile_npu.stop()\n\n                return result\n\n            return wrapper\n\n        return decorator\n"
  },
  {
    "path": "siirl/utils/debug/performance.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport datetime\nimport inspect\nimport logging\nfrom typing import Any, Tuple\n\nimport torch.distributed as dist\n\nfrom siirl.utils.extras.device import get_device_id, get_torch_device\nfrom siirl.utils.logger.aggregate_logger import DecoratorLoggerBase\n\n\ndef _get_current_mem_info(unit: str = \"GB\", precision: int = 2) -> Tuple[str]:\n    \"\"\"Get current memory usage.\"\"\"\n    assert unit in [\"GB\", \"MB\", \"KB\"]\n    divisor = 1024**3 if unit == \"GB\" else 1024**2 if unit == \"MB\" else 1024\n    mem_allocated = get_torch_device().memory_allocated()\n    mem_reserved = get_torch_device().memory_reserved()\n    # use get_torch_device().mem_get_info to profile device memory\n    # since vllm's sleep mode works below pytorch\n    # see https://github.com/vllm-project/vllm/pull/11743#issuecomment-2754338119\n    mem_free, mem_total = get_torch_device().mem_get_info()\n    mem_used = mem_total - mem_free\n    mem_allocated = f\"{mem_allocated / divisor:.{precision}f}\"\n    mem_reserved = f\"{mem_reserved / divisor:.{precision}f}\"\n    mem_used = f\"{mem_used / divisor:.{precision}f}\"\n    mem_total = f\"{mem_total / divisor:.{precision}f}\"\n    return mem_allocated, mem_reserved, mem_used, mem_total\n\n\ndef log_gpu_memory_usage(head: str, logger=None, level=\"DEBUG\", rank: int = 0):\n    if (not dist.is_initialized()) or (rank is None) or (dist.get_rank() == rank):\n        mem_allocated, mem_reserved, mem_used, mem_total = _get_current_mem_info()\n        message = f\"{head}, memory allocated (GB): {mem_allocated}, memory reserved (GB): {mem_reserved}, device memory used/total (GB): {mem_used}/{mem_total}\"\n\n        if logger is None:\n            print(message)\n        else:\n            logger.log(level, message)\n\n\nclass GPUMemoryLogger(DecoratorLoggerBase):\n    \"\"\"A decorator class to log GPU memory usage.\n    Example:\n        >>> from siirl.utils.debug.performance import GPUMemoryLogger\n        >>> @GPUMemoryLogger(role=\"actor\")\n        >>> def update_actor(self, batch):\n        ...     # real actor update logics\n        ...     
return\n    \"\"\"\n\n    def __init__(self, role: str, logger: logging.Logger = None, level=logging.DEBUG, log_only_rank_0: bool = True):\n        if dist.is_initialized() and dist.get_world_size() > 1:\n            rank = dist.get_rank()\n        else:\n            rank = 0\n        super().__init__(role, logger, level, rank, log_only_rank_0)\n\n    def __call__(self, decorated_function: callable):\n        def f(*args, **kwargs):\n            return self.log(decorated_function, *args, **kwargs)\n\n        return f\n\n    def log(self, func, *args, **kwargs):\n        name = func.__name__\n        mem_allocated, mem_reserved, mem_used, mem_total = _get_current_mem_info()\n        message = f\"Before {name}, memory allocated (GB): {mem_allocated}, memory reserved (GB): {mem_reserved}, device memory used/total (GB): {mem_used}/{mem_total}\"\n        self.logging_function(message)\n\n        output = func(*args, **kwargs)\n\n        mem_allocated, mem_reserved, mem_used, mem_total = _get_current_mem_info()\n        message = f\"After {name}, memory allocated (GB): {mem_allocated}, memory reserved (GB): {mem_reserved}, device memory used/total (GB): {mem_used}/{mem_total}\"\n\n        self.logging_function(message)\n        return output\n\n\ndef log_print(ctn: Any):\n    current_time = datetime.datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n\n    frame = inspect.currentframe().f_back\n    function_name = frame.f_code.co_name\n    line_number = frame.f_lineno\n    file_name = frame.f_code.co_filename.split(\"/\")[-1]\n    print(f\"[{file_name}:{line_number}:{function_name}]: {ctn}\")\n"
  },
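  {
    "path": "examples/sketches/gpu_memory_logger_usage.py",
    "content": "\"\"\"Hypothetical usage sketch for the debug/performance helpers (documentation only).\n\nThis file is not part of the original library. It illustrates `log_gpu_memory_usage` and\nthe `GPUMemoryLogger` decorator defined in siirl/utils/debug/performance.py. The\n`DummyActor` class and the batch value are made up, and a visible CUDA/NPU device is\nassumed because the helpers query device memory through get_torch_device().\n\"\"\"\n\nfrom siirl.utils.debug.performance import GPUMemoryLogger, log_gpu_memory_usage\n\n\nclass DummyActor:\n    # Logs allocated/reserved/used device memory before and after each call.\n    @GPUMemoryLogger(role=\"actor\")\n    def update_actor(self, batch):\n        # real actor update logic would go here\n        return batch\n\n\nif __name__ == \"__main__\":\n    # One-off snapshot; with logger=None the message is printed to stdout.\n    log_gpu_memory_usage(\"before update\", logger=None)\n    DummyActor().update_actor(batch={\"dummy\": 1})\n    log_gpu_memory_usage(\"after update\", logger=None)\n"
  },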
  {
    "path": "siirl/utils/debug/profile.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nfrom typing import Callable, Optional\n\nimport torch\nimport torch.distributed\n\nfrom siirl.params import ProfilerArguments\nfrom siirl.utils.extras.import_utils import is_nvtx_available\nfrom loguru import logger\n\n\nclass Profiler:\n    def __init__(self, config):\n        # note : if we do not set use_profile, it will be set as None, so that all function will be skip\n        self.config = config\n        self.skip_prof = False\n        self.saved = False\n        self.prof = None\n        self.rank = torch.distributed.get_rank()\n        # we need to validate the config before using the profiler\n        self._validate()\n        if config.use_profile and self.rank in self.config.profile_ranks:\n            print(f\"[Profiler] Profiler init for rank {self.rank}\")\n\n            self.prof = torch.profiler.profile(\n                activities=[\n                    torch.profiler.ProfilerActivity.CPU,\n                    torch.profiler.ProfilerActivity.CUDA,\n                ],\n                schedule=torch.profiler.schedule(\n                    wait=max(self.config.step_start - 1, 0),\n                    warmup=1 if self.config.step_start > 0 else 0,\n                    active=self.config.step_end - self.config.step_start,\n                    repeat=1,\n                ),\n                record_shapes=True,\n                with_stack=True,\n            )\n\n    def _validate(self):\n        if self.config.use_profile:\n            if self.config.profile_ranks is None:\n                print(\"[WARNING] Profile ranks is not set, default to rank 0\")\n                self.config.profile_ranks = [0]\n            assert self.config.step_start >= 0, \"[ERROR] Profile step start must be greater than 0\"\n            assert self.config.step_end >= 0, \"[ERROR] Profile step end must be greater than 0\"\n            assert self.config.step_start < self.config.step_end, \"[ERROR] Profile step start must be less than step end\"\n\n    def check(self):\n        return self.prof is not None and not self.skip_prof\n\n    def start(self):\n        if self.check():\n            print(f\"[Profiler] started for rank {self.rank}\")\n            self.prof.start()\n\n    def step(self):\n        if self.check():\n            self.prof.step()\n\n    def stop(self):\n        if self.check():\n            print(f\"[Profiler] stopped for rank {self.rank}\")\n            self.prof.stop()\n\n    def save(self):\n        if self.prof is not None and not self.saved:\n            if not os.path.exists(self.config.save_path):\n                os.makedirs(self.config.save_path)\n            save_file_name = f\"/prof_start_{self.config.step_start}_end_{self.config.step_end}_rank_{self.rank}.json\"\n            print(f\"[Profiler] Saving trace to {self.config.save_path + save_file_name}\")\n            self.prof.export_chrome_trace(self.config.save_path + save_file_name)\n  
          self.skip_prof = True\n            self.saved = True\n\n    def stop_and_save(self):\n        if self.check():\n            self.stop()\n            self.save()\n\n    def stop_trace(self):\n        if self.check():\n            print(f\"[Profiler] Trace stopped for rank {self.rank}\")\n            self.skip_prof = True\n\n\ndef mark_start_range(\n    message: Optional[str] = None,\n    color: Optional[str] = None,\n    domain: Optional[str] = None,\n    category: Optional[str] = None,\n) -> None:\n    \"\"\"Start a profiling range marker (no-op implementation).\n\n    Args:\n        message (Optional[str]): Message to associate with the range marker.\n        color (Optional[str]): Color for the marker visualization.\n        domain (Optional[str]): Domain for the marker.\n        category (Optional[str]): Category for the marker.\n    \"\"\"\n    pass\n\n\ndef mark_end_range(range_id: str) -> None:\n    \"\"\"End a profiling range marker (no-op implementation).\n\n    Args:\n        range_id (str): Identifier of the range to end.\n    \"\"\"\n    pass\n\n\ndef mark_annotate(\n    message: Optional[str] = None,\n    color: Optional[str] = None,\n    domain: Optional[str] = None,\n    category: Optional[str] = None,\n) -> Callable:\n    \"\"\"Decorator to annotate a function with profiling markers (no-op implementation).\n\n    Args:\n        message (Optional[str]): Message to associate with the annotation.\n        color (Optional[str]): Color for the marker visualization.\n        domain (Optional[str]): Domain for the marker.\n        category (Optional[str]): Category for the marker.\n\n    Returns:\n        Callable: Decorator function that returns the original function unchanged.\n    \"\"\"\n\n    def decorator(func):\n        return func\n\n    return decorator\n\n\nclass DistProfiler:\n\n    def __init__(self, rank: int, config: ProfilerArguments, **kwargs):\n        self.config = config\n        if self.config.enable and is_nvtx_available():\n            self.config.enable = False\n            logger.error(\"!!!!!!!!!!!!!!!Currently only support NPU profiling.!!!!!!!!!!!!!!!\")\n\n    def start(self, **kwargs):\n        pass\n\n    def stop(self):\n        pass\n\n    @staticmethod\n    def annotate(\n        message: Optional[str] = None,\n        color: Optional[str] = None,\n        domain: Optional[str] = None,\n        category: Optional[str] = None,\n        **kwargs,\n    ) -> Callable:\n        def decorator(func):\n            return func\n\n        return decorator\n"
  },
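  {
    "path": "examples/sketches/profiler_markers_usage.py",
    "content": "\"\"\"Hypothetical usage sketch for the profiling marker API (documentation only).\n\nThis file is not part of the original library. It illustrates the marker helpers\nre-exported from siirl.utils.debug: depending on what is installed they resolve to a\nbackend-specific implementation (e.g. mstx when torch_npu is available) or to the no-op\nstubs in siirl/utils/debug/profile.py, so the same call sites work on any machine. The\nfunction names and prompt strings below are made up.\n\"\"\"\n\nfrom siirl.utils.debug import mark_annotate, mark_end_range, mark_start_range\n\n\n@mark_annotate(message=\"generate_sequences\")\ndef generate_sequences(prompts):\n    # The decorator wraps the whole function lifetime in a single profiler range.\n    return [p.upper() for p in prompts]\n\n\ndef rollout_step(prompts):\n    # Explicit start/end markers around a sub-section of a larger function.\n    range_id = mark_start_range(message=\"rollout_step\")\n    outputs = generate_sequences(prompts)\n    mark_end_range(range_id)\n    return outputs\n\n\nif __name__ == \"__main__\":\n    print(rollout_step([\"hello\"]))\n"
  },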
  {
    "path": "siirl/utils/embodied/__init__.py",
    "content": ""
  },
  {
    "path": "siirl/utils/embodied/libero_utils.py",
    "content": "\"\"\"Utils for evaluating policies in LIBERO simulation environments.\"\"\"\n\nimport math\nimport os\n\nimport imageio\nimport numpy as np\nimport tensorflow as tf\nfrom libero.libero import get_libero_path\nfrom libero.libero.envs import OffScreenRenderEnv\nimport random\n# from experiments.robot.robot_utils import (\n#     DATE,\n#     DATE_TIME,\n# )\n\n\ndef get_libero_env(task, model_family, gpu_id=-1, resolution=256):\n    \"\"\"Initializes and returns the LIBERO environment, along with the task description.\"\"\"\n    task_description = task.language\n    task_bddl_file = os.path.join(get_libero_path(\"bddl_files\"), task.problem_folder, task.bddl_file)\n    env_args = {\n        \"bddl_file_name\": task_bddl_file,\n        \"camera_heights\": resolution,\n        \"camera_widths\": resolution,\n        \"render_gpu_device_id\": gpu_id\n    }\n    env = OffScreenRenderEnv(**env_args)\n    env.seed(0)  # IMPORTANT: seed seems to affect object positions even when using fixed initial state\n    return env, task_description\n\n\ndef get_libero_dummy_action(model_family: str):\n    \"\"\"Get dummy/no-op action, used to roll out the simulation while the robot does nothing.\"\"\"\n    return [0, 0, 0, 0, 0, 0, -1]\n\n\ndef resize_image(img, resize_size):\n    \"\"\"\n    Takes numpy array corresponding to a single image and returns resized image as numpy array.\n\n    NOTE (Moo Jin): To make input images in distribution with respect to the inputs seen at training time, we follow\n                    the same resizing scheme used in the Octo dataloader, which OpenVLA uses for training.\n    \"\"\"\n\n    assert isinstance(resize_size, tuple)\n    # Resize to image size expected by model\n    img = tf.image.encode_jpeg(img)  # Encode as JPEG, as done in RLDS dataset builder\n    img = tf.io.decode_image(img, expand_animations=False, dtype=tf.uint8)  # Immediately decode back\n    img = tf.image.resize(img, resize_size, method=\"lanczos3\", antialias=True)\n    img = tf.cast(tf.clip_by_value(tf.round(img), 0, 255), tf.uint8)\n    img = img.numpy()\n    return img\n\n\ndef get_libero_image(obs, resize_size):\n    \"\"\"Extracts image from observations and preprocesses it.\"\"\"\n    assert isinstance(resize_size, int) or isinstance(resize_size, tuple)\n    if isinstance(resize_size, int):\n        resize_size = (resize_size, resize_size)\n    img = obs[\"agentview_image\"]\n    img = img[::-1, ::-1]  # IMPORTANT: rotate 180 degrees to match train preprocessing\n    img = resize_image(img, resize_size)\n    return img\n\n\ndef get_libero_wrist_image(obs, resize_size):\n    \"\"\"Extracts image from observations and preprocesses it.\"\"\"\n    assert isinstance(resize_size, int) or isinstance(resize_size, tuple)\n    if isinstance(resize_size, int):\n        resize_size = (resize_size, resize_size)\n    img = obs[\"robot0_eye_in_hand_image\"]\n    img = img[::-1, ::-1]  # IMPORTANT: rotate 180 degrees to match train preprocessing\n    img = resize_image(img, resize_size)\n    return img\n\n# def save_rollout_video(rollout_images, idx, success, task_description, log_file=None):\n#     \"\"\"Saves an MP4 replay of an episode.\"\"\"\n#     rollout_dir = f\"./rollouts/{DATE}\"\n#     os.makedirs(rollout_dir, exist_ok=True)\n#     processed_task_description = task_description.lower().replace(\" \", \"_\").replace(\"\\n\", \"_\").replace(\".\", \"_\")[:50]\n#     mp4_path = f\"{rollout_dir}/{DATE_TIME}--episode={idx}--success={success}--task={processed_task_description}.mp4\"\n#     
video_writer = imageio.get_writer(mp4_path, fps=30)\n#     for img in rollout_images:\n#         video_writer.append_data(img)\n#     video_writer.close()\n#     print(f\"Saved rollout MP4 at path {mp4_path}\")\n#     if log_file is not None:\n#         log_file.write(f\"Saved rollout MP4 at path {mp4_path}\\n\")\n#     return mp4_path\n\n\ndef quat2axisangle(quat):\n    \"\"\"\n    Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55\n\n    Converts quaternion to axis-angle format.\n    Returns a unit vector direction scaled by its angle in radians.\n\n    Args:\n        quat (np.array): (x,y,z,w) vec4 float angles\n\n    Returns:\n        np.array: (ax,ay,az) axis-angle exponential coordinates\n    \"\"\"\n    # clip quaternion\n    if quat[3] > 1.0:\n        quat[3] = 1.0\n    elif quat[3] < -1.0:\n        quat[3] = -1.0\n\n    den = np.sqrt(1.0 - quat[3] * quat[3])\n    if math.isclose(den, 0.0):\n        # This is (close to) a zero degree rotation, immediately return\n        return np.zeros(3)\n\n    return (quat[:3] * 2.0 * math.acos(quat[3])) / den\n\ndef get_image_resize_size(cfg):\n    \"\"\"\n    Gets image resize size for a model class.\n    If `resize_size` is an int, then the resized image will be a square.\n    Else, the image will be a rectangle.\n    \"\"\"\n    if cfg.model_family == \"openvla\":\n        resize_size = 224\n    else:\n        raise ValueError(\"Unexpected `model_family` found in config.\")\n    return resize_size\n\n# def normalize_gripper_action(action, binarize=True):\n#     \"\"\"\n#     Changes gripper action (last dimension of action vector) from [0,1] to [-1,+1].\n#     Necessary for some environments (not Bridge) because the dataset wrapper standardizes gripper actions to [0,1].\n#     Note that unlike the other action dimensions, the gripper action is not normalized to [-1,+1] by default by\n#     the dataset wrapper.\n\n#     Normalization formula: y = 2 * (x - orig_low) / (orig_high - orig_low) - 1\n#     \"\"\"\n#     # Just normalize the last action to [-1,+1].\n#     orig_low, orig_high = 0.0, 1.0\n#     action[..., -1] = 2 * (action[..., -1] - orig_low) / (orig_high - orig_low) - 1\n\n#     if binarize:\n#         # Binarize to -1 or +1.\n#         action[..., -1] = np.sign(action[..., -1])\n\n#     return action\n\ndef normalize_gripper_action(action: np.ndarray, binarize: bool = True) -> np.ndarray:\n    \"\"\"\n    Normalize gripper action from [0,1] to [-1,+1] range.\n\n    This is necessary for some environments because the dataset wrapper\n    standardizes gripper actions to [0,1]. 
Note that unlike the other action\n    dimensions, the gripper action is not normalized to [-1,+1] by default.\n\n    Normalization formula: y = 2 * (x - orig_low) / (orig_high - orig_low) - 1\n\n    Args:\n        action: Action array with gripper action in the last dimension\n        binarize: Whether to binarize gripper action to -1 or +1\n\n    Returns:\n        np.ndarray: Action array with normalized gripper action\n    \"\"\"\n    # Create a copy to avoid modifying the original\n    normalized_action = action.copy()\n\n    # Normalize the last action dimension to [-1,+1]\n    orig_low, orig_high = 0.0, 1.0\n    normalized_action[..., -1] = 2 * (normalized_action[..., -1] - orig_low) / (orig_high - orig_low) - 1\n\n    if binarize:\n        # Binarize to -1 or +1\n        normalized_action[..., -1] = np.sign(normalized_action[..., -1])\n\n    return normalized_action\n\n\n# def invert_gripper_action(action):\n#     \"\"\"\n#     Flips the sign of the gripper action (last dimension of action vector).\n#     This is necessary for some environments where -1 = open, +1 = close, since\n#     the RLDS dataloader aligns gripper actions such that 0 = close, 1 = open.\n#     \"\"\"\n#     action[..., -1] = action[..., -1] * -1.0\n#     return action\n\ndef invert_gripper_action(action: np.ndarray) -> np.ndarray:\n    \"\"\"\n    Flip the sign of the gripper action (last dimension of action vector).\n\n    This is necessary for environments where -1 = open, +1 = close, since\n    the RLDS dataloader aligns gripper actions such that 0 = close, 1 = open.\n\n    Args:\n        action: Action array with gripper action in the last dimension\n\n    Returns:\n        np.ndarray: Action array with inverted gripper action\n    \"\"\"\n    # Create a copy to avoid modifying the original\n    inverted_action = action.copy()\n\n    # Invert the gripper action\n    inverted_action[..., -1] = inverted_action[..., -1] * -1.0\n\n    return inverted_action\n\ndef save_rollout_video(rollout_images, exp_name, task_name, step_idx, success):\n    \"\"\"Saves an MP4 replay of an episode.\"\"\"\n    rollout_dir = f\"./rollouts/{exp_name}\"\n    os.makedirs(rollout_dir, exist_ok=True)\n    ran_id = random.randint(1, 10000)\n    #processed_task_description = task_description.lower().replace(\" \", \"_\").replace(\"\\n\", \"_\").replace(\".\", \"_\")[:50]\n    mp4_path = f\"{rollout_dir}/step={step_idx}--task={task_name}--success={success}--ran={ran_id}.mp4\"\n    video_writer = imageio.get_writer(mp4_path, fps=30)\n    for img in rollout_images:\n        video_writer.append_data(img)\n    video_writer.close()\n    print(f\"Saved rollout MP4 at path {mp4_path}\")\n    return mp4_path\n"
  },
  {
    "path": "siirl/utils/embodied/openvla_utils.py",
    "content": "\"\"\"Utils for evaluating OpenVLA or fine-tuned OpenVLA policies.\"\"\"\n\nimport filecmp\nimport json\nimport os\nimport shutil\nimport time\nfrom datetime import datetime\nfrom pathlib import Path\nfrom typing import Any, Dict, List, Optional, Tuple, Union\n\nimport json_numpy\nimport numpy as np\nimport tensorflow as tf\nimport torch\nfrom loguru import logger\nfrom PIL import Image\nfrom transformers import AutoConfig, AutoImageProcessor, AutoModelForVision2Seq, AutoProcessor\n\n# Apply JSON numpy patch for serialization\njson_numpy.patch()\n\n# Configure NumPy print settings\nnp.set_printoptions(formatter={\"float\": lambda x: \"{0:0.3f}\".format(x)})\n\n\ndef update_auto_map(pretrained_checkpoint: str) -> None:\n    \"\"\"\n    Update the AutoMap configuration in the checkpoint config.json file.\n\n    This loads the config.json file inside the checkpoint directory and overwrites\n    the AutoConfig and AutoModelForVision2Seq fields to use OpenVLA-specific classes.\n\n    Uses file locking and atomic write to ensure thread-safety and prevent corruption.\n\n    Args:\n        pretrained_checkpoint: Path to the checkpoint directory\n    \"\"\"\n    if not os.path.isdir(pretrained_checkpoint):\n        return\n\n    config_path = os.path.join(pretrained_checkpoint, \"config.json\")\n    if not os.path.exists(config_path):\n        logger.warning(f\"No config.json found at {config_path}\")\n        return\n\n    import fcntl\n    import tempfile\n    \n    lock_path = os.path.join(pretrained_checkpoint, \".config.json.lock\")\n    max_retries = 5\n    retry_delay = 1.0\n    \n    for attempt in range(max_retries):\n        try:\n            # Acquire file lock for safe concurrent access\n            with open(lock_path, 'w') as lock_file:\n                # Wait up to 30 seconds for the lock\n                fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)\n                \n                try:\n                    # Re-check if config already has correct auto_map (another process may have updated it)\n                    with open(config_path, \"r\") as f:\n                        config = json.load(f)\n                    \n                    expected_auto_map = {\n                        \"AutoConfig\": \"configuration_prismatic.OpenVLAConfig\",\n                        \"AutoModelForVision2Seq\": \"modeling_prismatic.OpenVLAForActionPrediction\",\n                    }\n                    \n                    # If already correctly configured, skip update\n                    if config.get(\"auto_map\") == expected_auto_map:\n                        logger.info(f\"config.json already has correct auto_map, skipping update\")\n                        return\n                    \n                    # Create timestamped backup (only if we need to update)\n                    timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S_%f\")\n                    backup_path = os.path.join(pretrained_checkpoint, f\"config.json.back.{timestamp}\")\n                    shutil.copy2(config_path, backup_path)\n                    logger.info(f\"Created backup of original config at: {os.path.abspath(backup_path)}\")\n\n                    # Update the config\n                    config[\"auto_map\"] = expected_auto_map\n\n                    # Atomic write: write to temp file, then rename\n                    # This ensures readers never see a partial/corrupted file\n                    fd, temp_path = tempfile.mkstemp(\n                        dir=pretrained_checkpoint, \n          
              prefix=\".config.json.tmp.\",\n                        suffix=\".json\"\n                    )\n                    try:\n                        with os.fdopen(fd, 'w') as f:\n                            json.dump(config, f, indent=2)\n                            f.flush()\n                            os.fsync(f.fileno())  # Ensure data is written to disk\n                        \n                        # Atomic rename - this is atomic on POSIX systems\n                        os.replace(temp_path, config_path)\n                        \n                        logger.info(f\"Updated config.json at: {os.path.abspath(config_path)}\")\n                        logger.info(\"Changes made:\")\n                        logger.info('  - Set AutoConfig to \"configuration_prismatic.OpenVLAConfig\"')\n                        logger.info('  - Set AutoModelForVision2Seq to \"modeling_prismatic.OpenVLAForActionPrediction\"')\n                        return\n                    except Exception as e:\n                        # Clean up temp file if something went wrong\n                        if os.path.exists(temp_path):\n                            os.unlink(temp_path)\n                        raise\n                finally:\n                    # Release lock\n                    fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)\n                    \n        except (json.JSONDecodeError, IOError) as e:\n            if attempt < max_retries - 1:\n                logger.warning(f\"Attempt {attempt + 1}/{max_retries} failed: {e}. Retrying in {retry_delay}s...\")\n                time.sleep(retry_delay)\n                retry_delay *= 1.5  # Exponential backoff\n            else:\n                logger.error(f\"Failed to update config.json after {max_retries} attempts: {e}\")\n                raise\n        except Exception as e:\n            logger.error(f\"Unexpected error updating config.json: {e}\")\n            raise\n\n\ndef check_identical_files(path1: Union[str, Path], path2: Union[str, Path]) -> bool:\n    \"\"\"\n    Check if two files are identical in content.\n\n    Args:\n        path1: Path to the first file\n        path2: Path to the second file\n\n    Returns:\n        bool: True if files are identical, False otherwise\n    \"\"\"\n    path1, path2 = Path(path1), Path(path2)\n\n    # First check if file sizes match\n    if path1.stat().st_size != path2.stat().st_size:\n        return False\n\n    # Check if contents match\n    return filecmp.cmp(path1, path2, shallow=False)\n\n\ndef _handle_file_sync(curr_filepath: str, checkpoint_filepath: str, file_type: str) -> None:\n    \"\"\"\n    Handle syncing of files between current directory and checkpoint.\n\n    Creates backups if files exist but differ, and copies current versions to checkpoint.\n\n    Args:\n        curr_filepath: Path to the current file version\n        checkpoint_filepath: Path where the file should be in the checkpoint\n        file_type: Description of the file type for logging\n    \"\"\"\n    if os.path.exists(checkpoint_filepath):\n        # Check if existing files are identical\n        match = check_identical_files(curr_filepath, checkpoint_filepath)\n\n        if not match:\n            logger.info(\n                \"\\n------------------------------------------------------------------------------------------------\\n\"\n                f\"Found mismatch between:\\n\"\n                f\"Current:   {curr_filepath}\\n\"\n                f\"Checkpoint: {checkpoint_filepath}\\n\"\n            )\n\n 
           # Create timestamped backup\n            timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n            backup_path = f\"{checkpoint_filepath}.back.{timestamp}\"\n            shutil.copy2(checkpoint_filepath, backup_path)\n            logger.info(f\"Created backup of original checkpoint file at: {os.path.abspath(backup_path)}\")\n\n            # Copy current version to checkpoint directory\n            shutil.copy2(curr_filepath, checkpoint_filepath)\n            logger.info(f\"Copied current version to checkpoint at: {os.path.abspath(checkpoint_filepath)}\")\n            logger.info(\n                f\"Changes complete. The checkpoint will now use the current version of {file_type}\"\n                \"\\n------------------------------------------------------------------------------------------------\\n\"\n            )\n    else:\n        # If file doesn't exist in checkpoint directory, copy it\n        shutil.copy2(curr_filepath, checkpoint_filepath)\n        logger.info(\n            \"\\n------------------------------------------------------------------------------------------------\\n\"\n            f\"No {file_type} found in checkpoint directory.\\n\"\n            f\"Copied current version from: {curr_filepath}\\n\"\n            f\"To checkpoint location: {os.path.abspath(checkpoint_filepath)}\"\n            \"\\n------------------------------------------------------------------------------------------------\\n\"\n        )\n\n\ndef check_model_logic_mismatch(pretrained_checkpoint: str) -> None:\n    \"\"\"\n    Check and sync model logic files between current code and checkpoint.\n\n    Handles the relationship between current and checkpoint versions of both\n    modeling_prismatic.py and configuration_prismatic.py:\n    - If checkpoint file exists and differs: creates backup and copies current version\n    - If checkpoint file doesn't exist: copies current version\n\n    Args:\n        pretrained_checkpoint: Path to the checkpoint directory\n    \"\"\"\n    if not os.path.isdir(pretrained_checkpoint):\n        return\n\n    # Find current files\n    curr_files = {\"modeling_prismatic.py\": None, \"configuration_prismatic.py\": None}\n\n    for root, _, files in os.walk(\"./prismatic/\"):\n        for filename in curr_files.keys():\n            if filename in files and curr_files[filename] is None:\n                curr_files[filename] = os.path.join(root, filename)\n\n    # Check and handle each file\n    for filename, curr_filepath in curr_files.items():\n        if curr_filepath is None:\n            logger.warning(f\"`{filename}` is not found anywhere in the current directory.\")\n            continue\n\n        checkpoint_filepath = os.path.join(pretrained_checkpoint, filename)\n        _handle_file_sync(curr_filepath, checkpoint_filepath, filename)\n\n\n\n\n\n\n\n"
  },
  {
    "path": "siirl/utils/embodied/video_emb.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport numpy as np\nimport torch\n\nfrom loguru import logger\n\nimport src.datasets.utils.video.transforms as video_transforms\nimport src.datasets.utils.video.volume_transforms as volume_transforms\nfrom src.models.vision_transformer import vit_giant_xformers_rope\n\n# Constants for video normalization\nIMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)\nIMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)\n\n\nclass VideoEmbeddingModel:\n    \"\"\"\n    A self-contained class to load the V-JEPA model, preprocess video frames,\n    and extract embeddings. Each instance is tied to a specific GPU device.\n    \"\"\"\n    def __init__(self, model_path: str, img_size: int = 384, device_id: int = 0, enable_fp16: bool = False) -> None:\n        self.model_path = model_path\n        self.img_size = img_size\n        self.device = f\"cuda:{device_id}\"\n        self.auto_cast_dtype = torch.float16 if enable_fp16 else torch.float32\n        logger.info(f\"Initializing embedding model on device: {self.device}, enbale_fp16={enable_fp16}. \"\n                    \"It will take several minutes, please be patient...\")\n        self.pt_video_transform, self.model_pt = self._create_model_instance()\n        self.embedding_dim = self.model_pt.norm.bias.shape[0]\n        self.num_frames_for_embedding = 64\n        logger.info(f\"Embedding model loaded successfully on {self.device}\")\n\n    def _build_pt_video_transform(self):\n        \"\"\"Builds the video transformation pipeline for the model.\"\"\"\n        short_side_size = int(256.0 / 224 * self.img_size)\n        eval_transform = video_transforms.Compose(\n            [\n                video_transforms.Resize(short_side_size, interpolation=\"bilinear\"),\n                video_transforms.CenterCrop(size=(self.img_size, self.img_size)),\n                volume_transforms.ClipToTensor(),\n                video_transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),\n            ]\n        )\n        return eval_transform\n\n    def _load_pretrained_vjepa_pt_weights(self, model, pretrained_weights):\n        \"\"\"Loads pretrained weights into the V-JEPA model.\"\"\"\n        try:\n            pretrained_dict = torch.load(pretrained_weights, weights_only=True, map_location=self.device)[\"encoder\"]\n            pretrained_dict = {k.replace(\"module.\", \"\"): v for k, v in pretrained_dict.items()}\n            pretrained_dict = {k.replace(\"backbone.\", \"\"): v for k, v in pretrained_dict.items()}\n            msg = model.load_state_dict(pretrained_dict, strict=False)\n            logger.info(f\"Pretrained weights found at {pretrained_weights} and loaded with msg: {msg}\")\n        except Exception as e:\n            logger.error(f\"Failed to load pretrained weights from {pretrained_weights}: {e}\")\n            raise\n\n    def _create_model_instance(self):\n        \"\"\"Creates and prepares the 
V-JEPA model instance.\"\"\"\n        model_pt = vit_giant_xformers_rope(\n            img_size=(self.img_size, self.img_size),\n            num_frames=64)\n        self._load_pretrained_vjepa_pt_weights(model_pt, self.model_path)\n        model_pt.eval().to(self.device)\n        pt_video_transform = self._build_pt_video_transform()\n        return pt_video_transform, model_pt\n\n    def offload_to_host(self):\n        \"\"\"Offloads the model to CPU to free up GPU memory.\"\"\"\n        self.model_pt.to(\"cpu\")\n        torch.cuda.empty_cache()\n        logger.info(f\"Video embedding model offloaded to CPU.\")\n\n    def load_to_device(self):\n        \"\"\"Loads the model back to the assigned GPU device.\"\"\"\n        self.model_pt.to(self.device)\n        logger.info(f\"Video embedding model loaded back to {self.device}.\")\n\n    def extract_video_embedding(self, video_tensor):\n        \"\"\"\n        Extracts the embedding from a given video tensor.\n        Args:\n            video_tensor (torch.Tensor): A tensor of shape (T, C, H, W).\n        Returns:\n            np.ndarray: The computed embedding vector.\n        \"\"\"\n        if video_tensor is None:\n            return None\n        with torch.inference_mode():\n            # The transform expects a tensor on the correct device\n            x = self.pt_video_transform(video_tensor.to(self.device)).to(self.device).unsqueeze(0)\n            with torch.amp.autocast('cuda', dtype=self.auto_cast_dtype):\n                embedding = self.model_pt(x)\n            return embedding.mean(dim=1).to(torch.float32).squeeze(0).cpu().numpy()\n\n    def extract_video_embedding_batch(self, video_tensor_list):\n        \"\"\"\n        Extracts embeddings for a batch of video tensors.\n        Args:\n            video_tensor_list (list of torch.Tensor or None): List of video tensors.\n        Returns:\n            list of np.ndarray: List of computed embedding vectors.\n        \"\"\"\n        with torch.inference_mode():\n            input_list = [self.pt_video_transform(v.to(self.device)).to(self.device) for v in video_tensor_list]\n            x = torch.stack(input_list, dim=0)\n            with torch.amp.autocast('cuda', dtype=self.auto_cast_dtype):\n                embedding = self.model_pt(x)\n            return [e.mean(dim=0).to(torch.float32).cpu().numpy() for e in embedding]\n        \n    def get_embeddings(self, batch_names, batch_frames):\n        \"\"\"\n        Processes video frames in memory and returns embeddings for each task.\n        Handles missing frames gracefully and batches all embedding extractions for efficiency.\n        All videos are processed in a single batch after padding shorter videos.\n        Args:\n            batch_names (list of str): List of task names or identifiers.\n            batch_frames (list of list of np.ndarray): List of video frames for each task.\n        Returns:\n            list of np.ndarray: List of embedding vectors for each task.\n        \"\"\"\n        assert len(batch_names) == len(batch_frames), \"Names and frames lists must be of the same length.\"\n        embedding_list = [np.zeros((self.embedding_dim), dtype=np.float32) for _ in batch_names]\n        \n        video_tensors_to_process = []\n        original_indices = []\n\n        for idx, (name, frames) in enumerate(zip(batch_names, batch_frames)):\n            if not frames:\n                logger.warning(f\"== Found 0 frames for video {name}, returning zero embedding ==\")\n                continue\n            try:\n    
            total_frames = len(frames)\n                if total_frames >= self.num_frames_for_embedding:\n                    selected_indices = np.linspace(0, total_frames - 1, num=self.num_frames_for_embedding, dtype=int)\n                    sampled_frames = [frames[i] for i in selected_indices]\n                else:\n                    logger.warning(f\"Video {name} has only {total_frames} frames. Padding to {self.num_frames_for_embedding}.\")\n                    indices = np.arange(total_frames)\n                    padded_indices = np.resize(indices, self.num_frames_for_embedding)\n                    sampled_frames = [frames[i] for i in padded_indices]\n                # Convert list of numpy arrays to a single torch tensor\n                # The frames are (H, W, C), convert to (T, C, H, W)\n                video_tensor = torch.from_numpy(np.stack(sampled_frames)).permute(0, 3, 1, 2)\n                video_tensors_to_process.append(video_tensor)\n                original_indices.append(idx)\n            except Exception as e:\n                logger.error(f\"Error processing frames for {name}: {e}\")\n\n        if not video_tensors_to_process:\n            return embedding_list\n\n        batch_embeddings = self.extract_video_embedding_batch(video_tensors_to_process)\n        for i, emb in zip(original_indices, batch_embeddings):\n            embedding_list[i] = emb\n\n        return embedding_list\n"
  },
  {
    "path": "siirl/utils/experimental/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "siirl/utils/experimental/torch_functional.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Optional, Tuple\n\nimport torch\n\n\ndef _fused_linear_for_ppo_fwd(hidden_states: torch.FloatTensor, vocab_weights: torch.FloatTensor, input_ids: torch.LongTensor, temperature: float = 1.0) -> Tuple[torch.FloatTensor, torch.FloatTensor]:\n    logits = (hidden_states @ vocab_weights.t()) / temperature\n    orig_dtype = logits.dtype\n    logits = logits.to(torch.float32)\n\n    # Slower but more numerically stable to do log_softmax than probs.log()\n    probs = logits.softmax(dim=-1)\n    log_probs = logits.log_softmax(dim=-1)\n\n    token_log_probs = log_probs.gather(-1, input_ids.unsqueeze(-1)).squeeze(-1)\n    entropy = torch.logsumexp(logits, dim=-1) - torch.sum(probs * logits, dim=-1)\n\n    return token_log_probs.to(orig_dtype), entropy.to(orig_dtype)\n\n\ndef _fused_linear_for_ppo_bwd(\n    dlog_probs: Optional[torch.FloatTensor],\n    dentropy: Optional[torch.FloatTensor],\n    hidden_states: torch.FloatTensor,\n    vocab_weights: torch.FloatTensor,\n    input_ids: torch.LongTensor,\n    temperature: float = 1.0,\n) -> Tuple[torch.FloatTensor, torch.FloatTensor]:\n    logits = (hidden_states @ vocab_weights.t()) / temperature\n    orig_dtype = logits.dtype\n    logits = logits.to(torch.float32)\n\n    probs = logits.softmax(dim=-1)\n\n    dlogits = 0\n\n    # Gradient from log_probs\n    if dlog_probs is not None:\n        one_hot_input = torch.zeros_like(logits).scatter_(-1, input_ids.unsqueeze(-1), 1)\n        dlogits += dlog_probs.to(torch.float32).unsqueeze(-1) * (one_hot_input - probs)\n\n    # Gradient from entropy\n    if dentropy is not None:\n        log_probs = logits.log_softmax(dim=-1)\n        entropy = torch.logsumexp(logits, dim=-1) - torch.sum(probs * logits, dim=-1)\n        dlogits += probs * (log_probs + entropy.unsqueeze(-1)) * (-dentropy.unsqueeze(-1))\n\n    dlogits = dlogits.to(orig_dtype) / temperature\n\n    dhidden_states = dlogits @ vocab_weights\n    dvocab_weights = dlogits.t() @ hidden_states\n\n    return dhidden_states, dvocab_weights\n\n\nclass FusedLinearForPPOFunction(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        hidden_states: torch.FloatTensor,\n        vocab_weights: torch.FloatTensor,\n        input_ids: torch.LongTensor,\n        temperature: float = 1.0,\n        chunk_size: int = 512,\n    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:\n        ctx.set_materialize_grads(False)\n\n        # Cast to a 2D tensor of the shape [T, D] for ease of working\n        orig_ndim = hidden_states.ndim\n        assert orig_ndim in (2, 3), f\"Invalid hidden_states shape, received {hidden_states.shape}\"\n\n        orig_batch_size = -1\n        if orig_ndim == 3:\n            assert input_ids.ndim == 2, f\"input_ids shape doesn't match, {hidden_states.shape} {input_ids.shape}\"\n            orig_batch_size = hidden_states.shape[0]\n            hidden_states = 
hidden_states.flatten(0, 1)\n            input_ids = input_ids.flatten(0, 1)\n\n        T = hidden_states.shape[0]\n\n        # Allocate memory for outputs\n        output_requires_grad = hidden_states.requires_grad or vocab_weights.requires_grad\n        log_probs = hidden_states.new_zeros(T, requires_grad=output_requires_grad)\n        entropy = hidden_states.new_zeros(T, requires_grad=output_requires_grad)\n\n        # Perform forward one chunk at a time\n        for chunk_start in range(0, T, chunk_size):\n            chunk_end = min(chunk_start + chunk_size, T)\n\n            chunk_log_probs, chunk_entropy = _fused_linear_for_ppo_fwd(\n                hidden_states=hidden_states[chunk_start:chunk_end],\n                vocab_weights=vocab_weights,\n                input_ids=input_ids[chunk_start:chunk_end],\n                temperature=temperature,\n            )\n            log_probs[chunk_start:chunk_end] = chunk_log_probs\n            entropy[chunk_start:chunk_end] = chunk_entropy\n\n        # Cast the output back to the original input dimension\n        if orig_ndim == 3:\n            log_probs = log_probs.view(orig_batch_size, -1)\n            entropy = entropy.view(orig_batch_size, -1)\n\n        ctx.save_for_backward(hidden_states, vocab_weights, input_ids)\n        ctx.orig_batch_size = orig_batch_size\n        ctx.orig_ndim = orig_ndim\n        ctx.temperature = temperature\n        ctx.chunk_size = chunk_size\n\n        return log_probs, entropy\n\n    @staticmethod\n    def backward(ctx, dlog_probs: Optional[torch.FloatTensor], dentropy: Optional[torch.FloatTensor]):\n        assert dlog_probs is not None or dentropy is not None\n\n        hidden_states, vocab_weights, input_ids = ctx.saved_tensors\n        orig_batch_size = ctx.orig_batch_size\n        orig_ndim = ctx.orig_ndim\n        temperature = ctx.temperature\n        chunk_size = ctx.chunk_size\n\n        # Here orig_ndim refers to the orig_ndim of hidden_states\n        if orig_ndim == 3:\n            if dlog_probs is not None:\n                dlog_probs = dlog_probs.flatten()\n            if dentropy is not None:\n                dentropy = dentropy.flatten()\n\n        T = hidden_states.shape[0]\n\n        # Allocate memory for outputs\n        dhidden_states = None\n        if hidden_states.requires_grad:\n            dhidden_states = torch.zeros_like(hidden_states)\n        dvocab_weights = None\n        if vocab_weights.requires_grad:\n            dvocab_weights = torch.zeros_like(vocab_weights)\n\n        # Perform backward one chunk at a time\n        for chunk_start in range(0, T, chunk_size):\n            chunk_end = min(chunk_start + chunk_size, T)\n            chunk_dlog_probs = None\n            if dlog_probs is not None:\n                chunk_dlog_probs = dlog_probs[chunk_start:chunk_end]\n            chunk_dentropy = None\n            if dentropy is not None:\n                chunk_dentropy = dentropy[chunk_start:chunk_end]\n\n            h, v = _fused_linear_for_ppo_bwd(\n                dlog_probs=chunk_dlog_probs,\n                dentropy=chunk_dentropy,\n                hidden_states=hidden_states[chunk_start:chunk_end],\n                vocab_weights=vocab_weights,\n                input_ids=input_ids[chunk_start:chunk_end],\n                temperature=temperature,\n            )\n\n            if hidden_states.requires_grad:\n                dhidden_states[chunk_start:chunk_end] += h\n            if vocab_weights.requires_grad:\n                dvocab_weights += v\n\n        # Cast the 
output back to the original input dimension\n        if orig_ndim == 3 and hidden_states.requires_grad:\n            hidden_size = hidden_states.shape[-1]\n            dhidden_states = dhidden_states.view(orig_batch_size, -1, hidden_size)\n\n        return (\n            dhidden_states,  # hidden_states\n            dvocab_weights,  # vocab_weights\n            None,  # input_ids\n            None,  # temperature\n            None,  # chunk_size\n        )\n\n\nclass FusedLinearForPPO(torch.nn.Module):\n    def __init__(self, chunk_size: int = 512):\n        super().__init__()\n\n        self.chunk_size = chunk_size\n\n    def forward(\n        self,\n        hidden_states: torch.FloatTensor,\n        vocab_weights: torch.FloatTensor,\n        input_ids: torch.LongTensor,\n        temperature: float = 1.0,\n    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:\n        input_ids = input_ids.to(torch.int64)\n        return FusedLinearForPPOFunction.apply(\n            hidden_states,\n            vocab_weights,\n            input_ids,\n            temperature,\n            self.chunk_size,\n        )\n"
  },
  {
    "path": "siirl/utils/extras/__init__.py",
    "content": ""
  },
  {
    "path": "siirl/utils/extras/device.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# This code is inspired by the torchtune.\n# https://github.com/pytorch/torchtune/blob/main/torchtune/utils/_device.py\n#\n# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the BSD-style license in https://github.com/pytorch/torchtune/blob/main/LICENSE\n\nimport logging\n\nimport torch\n\nlogger = logging.getLogger(__name__)\n\n\ndef is_torch_npu_available() -> bool:\n    \"\"\"Check the availability of NPU\"\"\"\n    try:\n        import torch_npu  # noqa: F401\n\n        return torch.npu.is_available()\n    except ImportError:\n        return False\n\n\nis_cuda_available = torch.cuda.is_available()\nis_npu_available = is_torch_npu_available()\n\n\ndef get_device_name() -> str:\n    \"\"\"Function that gets the torch.device based on the current machine.\n    This currently only supports CPU, CUDA, NPU.\n    Returns:\n        device\n    \"\"\"\n    if is_cuda_available:\n        device = \"cuda\"\n    elif is_npu_available:\n        device = \"npu\"\n    else:\n        device = \"cpu\"\n    return device\n\n\ndef get_torch_device() -> any:\n    \"\"\"Return the corresponding torch attribute based on the device type string.\n    Returns:\n        module: The corresponding torch device namespace, or torch.cuda if not found.\n    \"\"\"\n    device_name = get_device_name()\n    try:\n        return getattr(torch, device_name)\n    except AttributeError:\n        logger.warning(f\"Device namespace '{device_name}' not found in torch, try to load torch.cuda.\")\n        return torch.cuda\n\n\ndef get_device_id() -> int:\n    \"\"\"Return current device id based on the device type.\n    Returns:\n        device index\n    \"\"\"\n    return get_torch_device().current_device()\n\ndef get_nccl_backend() -> str:\n    \"\"\"Return nccl backend type based on the device type.\n    Returns:\n        nccl backend type string.\n    \"\"\"\n    if is_cuda_available:\n        return \"nccl\"\n    elif is_npu_available:\n        return \"hccl\"\n    else:\n        raise RuntimeError(f\"No available nccl backend found on device type {get_device_name()}.\")\n    return get_torch_device().current_device()\n\ndef device_synchronize():\n    \"\"\"\n    Synchronize the current device to ensure that all previously\n    launched operations have completed.\n\n    - If CUDA is available, it calls `torch.cuda.synchronize()`.\n    - If NPU is available, it calls `torch.npu.synchronize()`.\n\n    \"\"\"\n    if is_cuda_available:\n        torch.cuda.synchronize()\n    elif is_npu_available:\n        torch.npu.synchronize()\n\ndef set_expandable_segments(enable: bool) -> None:\n    \"\"\"Enable or disable expandable segments for cuda.\n    Args:\n        enable (bool): Whether to enable expandable segments. Used to avoid OOM.\n    \"\"\"\n    if is_cuda_available:\n        torch.cuda.memory._set_allocator_settings(f\"expandable_segments:{enable}\")"
  },
  {
    "path": "siirl/utils/extras/fs.py",
    "content": "#!/usr/bin/env python\n# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# -*- coding: utf-8 -*-\n\"\"\"File-system agnostic IO APIs\"\"\"\n\nimport hashlib\nimport os\nimport shutil\nimport tempfile\n\ntry:\n    from hdfs_io import copy, exists, makedirs  # for internal use only\nexcept ImportError:\n    from .hdfs_io import copy, exists, makedirs\n\n__all__ = [\"copy\", \"exists\", \"makedirs\"]\n\n_HDFS_PREFIX = \"hdfs://\"\n\n\ndef is_non_local(path):\n    \"\"\"Check if a path is a non-local (HDFS) path.\n\n    Args:\n        path (str): The path to check.\n\n    Returns:\n        bool: True if the path is an HDFS path, False otherwise.\n    \"\"\"\n    return path.startswith(_HDFS_PREFIX)\n\n\ndef md5_encode(path: str) -> str:\n    \"\"\"Generate an MD5 hash of a path string.\n\n    This function is used to create unique identifiers for paths, typically\n    for creating cache directories or lock files.\n\n    Args:\n        path (str): The path to encode.\n\n    Returns:\n        str: The hexadecimal MD5 hash of the path.\n    \"\"\"\n    return hashlib.md5(path.encode()).hexdigest()\n\n\ndef get_local_temp_path(hdfs_path: str, cache_dir: str) -> str:\n    \"\"\"Generate a unique local cache path for an HDFS resource.\n    Creates a MD5-hashed subdirectory in cache_dir to avoid name conflicts,\n    then returns path combining this subdirectory with the HDFS basename.\n\n    Args:\n        hdfs_path (str): Source HDFS path to be cached\n        cache_dir (str): Local directory for storing cached files\n\n    Returns:\n        str: Absolute local filesystem path in format:\n            {cache_dir}/{md5(hdfs_path)}/{basename(hdfs_path)}\n    \"\"\"\n    # make a base64 encoding of hdfs_path to avoid directory conflict\n    encoded_hdfs_path = md5_encode(hdfs_path)\n    temp_dir = os.path.join(cache_dir, encoded_hdfs_path)\n    os.makedirs(temp_dir, exist_ok=True)\n    dst = os.path.join(temp_dir, os.path.basename(hdfs_path))\n    return dst\n\n\ndef verify_copy(src: str, dest: str) -> bool:\n    \"\"\"\n    verify the copy of src to dest by comparing their sizes and file structures.\n\n    return:\n        bool: True if the copy is verified, False otherwise.\n    \"\"\"\n    if not os.path.exists(src):\n        return False\n    if not os.path.exists(dest):\n        return False\n\n    if os.path.isfile(src) != os.path.isfile(dest):\n        return False\n\n    if os.path.isfile(src):\n        src_size = os.path.getsize(src)\n        dest_size = os.path.getsize(dest)\n        if src_size != dest_size:\n            return False\n        return True\n\n    src_files = set()\n    dest_files = set()\n\n    for root, dirs, files in os.walk(src):\n        rel_path = os.path.relpath(root, src)\n        dest_root = os.path.join(dest, rel_path) if rel_path != \".\" else dest\n\n        if not os.path.exists(dest_root):\n            return False\n\n        for entry in os.listdir(root):\n            src_entry = 
os.path.join(root, entry)\n            src_files.add(os.path.relpath(src_entry, src))\n\n        for entry in os.listdir(dest_root):\n            dest_entry = os.path.join(dest_root, entry)\n            dest_files.add(os.path.relpath(dest_entry, dest))\n\n    if src_files != dest_files:\n        return False\n\n    for rel_path in src_files:\n        src_entry = os.path.join(src, rel_path)\n        dest_entry = os.path.join(dest, rel_path)\n\n        if os.path.isdir(src_entry) != os.path.isdir(dest_entry):\n            return False\n\n        if os.path.isfile(src_entry):\n            src_size = os.path.getsize(src_entry)\n            dest_size = os.path.getsize(dest_entry)\n            if src_size != dest_size:\n                return False\n\n    return True\n\n\ndef copy_to_shm(src: str):\n    \"\"\"\n    Copy the model into /dev/shm so that loading it multiple times is more efficient.\n    \"\"\"\n    shm_model_root = \"/dev/shm/siirl-cache/\"\n    src_abs = os.path.abspath(os.path.normpath(src))\n    dest = os.path.join(shm_model_root, hashlib.md5(src_abs.encode(\"utf-8\")).hexdigest())\n    os.makedirs(dest, exist_ok=True)\n    dest = os.path.join(dest, os.path.basename(src_abs))\n    if os.path.exists(dest) and verify_copy(src, dest):\n        # Inform the user and let them decide what to do\n        print(f\"[WARNING]: The memory model path {dest} already exists. If this is not what you want, please clear it and restart the task.\")\n    else:\n        if os.path.isdir(src):\n            shutil.copytree(src, dest, symlinks=False, dirs_exist_ok=True)\n        else:\n            shutil.copy2(src, dest)\n    return dest\n\n\ndef _record_directory_structure(folder_path):\n    record_file = os.path.join(folder_path, \".directory_record.txt\")\n    with open(record_file, \"w\") as f:\n        for root, dirs, files in os.walk(folder_path):\n            for dir_name in dirs:\n                relative_dir = os.path.relpath(os.path.join(root, dir_name), folder_path)\n                f.write(f\"dir:{relative_dir}\\n\")\n            for file_name in files:\n                if file_name != \".directory_record.txt\":\n                    relative_file = os.path.relpath(os.path.join(root, file_name), folder_path)\n                    f.write(f\"file:{relative_file}\\n\")\n    return record_file\n\n\ndef _check_directory_structure(folder_path, record_file):\n    if not os.path.exists(record_file):\n        return False\n    existing_entries = set()\n    for root, dirs, files in os.walk(folder_path):\n        for dir_name in dirs:\n            relative_dir = os.path.relpath(os.path.join(root, dir_name), folder_path)\n            existing_entries.add(f\"dir:{relative_dir}\")\n        for file_name in files:\n            if file_name != \".directory_record.txt\":\n                relative_file = os.path.relpath(os.path.join(root, file_name), folder_path)\n                existing_entries.add(f\"file:{relative_file}\")\n    with open(record_file) as f:\n        recorded_entries = set(f.read().splitlines())\n    return existing_entries == recorded_entries\n\n\ndef copy_to_local(src: str, cache_dir=None, filelock=\".file.lock\", verbose=False, always_recopy=False, use_shm: bool = False) -> str:\n    \"\"\"Copy files/directories from HDFS to local cache with validation.\n\n    Args:\n        src (str): Source path - HDFS path (hdfs://...) or local filesystem path\n        cache_dir (str, optional): Local directory for cached files. 
Uses system tempdir if None\n        filelock (str): Base name for file lock. Defaults to \".file.lock\"\n        verbose (bool): Enable copy operation logging. Defaults to False\n        always_recopy (bool): Force fresh copy ignoring cache. Defaults to False\n        use_shm (bool): Enable shared memory copy. Defaults to False\n\n    Returns:\n        str: Local filesystem path to copied resource\n    \"\"\"\n    # Save to a local path for persistence.\n    local_path = copy_local_path_from_hdfs(src, cache_dir, filelock, verbose, always_recopy)\n    # Load into shm to improve efficiency.\n    if use_shm:\n        return copy_to_shm(local_path)\n    return local_path\n\n\ndef copy_local_path_from_hdfs(src: str, cache_dir=None, filelock=\".file.lock\", verbose=False, always_recopy=False) -> str:\n    \"\"\"Deprecated. Please use copy_to_local instead.\"\"\"\n    from filelock import FileLock\n\n    assert src[-1] != \"/\", f\"Make sure the last char in src is not / because it will cause error. Got {src}\"\n\n    if is_non_local(src):\n        # download from hdfs to local\n        if cache_dir is None:\n            # get a temp folder\n            cache_dir = tempfile.gettempdir()\n        os.makedirs(cache_dir, exist_ok=True)\n        assert os.path.exists(cache_dir)\n        local_path = get_local_temp_path(src, cache_dir)\n        # get a specific lock\n        filelock = md5_encode(src) + \".lock\"\n        lock_file = os.path.join(cache_dir, filelock)\n        with FileLock(lock_file=lock_file):\n            if always_recopy and os.path.exists(local_path):\n                if os.path.isdir(local_path):\n                    shutil.rmtree(local_path, ignore_errors=True)\n                else:\n                    os.remove(local_path)\n            if not os.path.exists(local_path):\n                if verbose:\n                    print(f\"Copy from {src} to {local_path}\")\n                copy(src, local_path)\n                if os.path.isdir(local_path):\n                    _record_directory_structure(local_path)\n            elif os.path.isdir(local_path):\n                # always_recopy=False, local path exists, and it is a folder: check whether there is anything missed\n                record_file = os.path.join(local_path, \".directory_record.txt\")\n                if not _check_directory_structure(local_path, record_file):\n                    if verbose:\n                        print(f\"Recopy from {src} to {local_path} due to missing files or directories.\")\n                    shutil.rmtree(local_path, ignore_errors=True)\n                    copy(src, local_path)\n                    _record_directory_structure(local_path)\n        return local_path\n    else:\n        return src\n\ndef local_mkdir_safe(path):\n    \"\"\"_summary_\n    Thread-safe directory creation function that ensures the directory is created\n    even if multiple processes attempt to create it simultaneously.\n\n    Args:\n        path (str): The path to create a directory at.\n    \"\"\"\n\n    from filelock import FileLock\n\n    if not os.path.isabs(path):\n        working_dir = os.getcwd()\n        path = os.path.join(working_dir, path)\n\n    # Using hash value of path as lock file name to avoid long file name\n    lock_filename = f\"ckpt_{hash(path) & 0xFFFFFFFF:08x}.lock\"\n    lock_path = os.path.join(tempfile.gettempdir(), lock_filename)\n\n    try:\n        with FileLock(lock_path, timeout=60):  # Add timeout\n            # make a new dir\n            os.makedirs(path, exist_ok=True)\n    
except Exception as e:\n        print(f\"Warning: Failed to acquire lock for {path}: {e}\")\n        # Even if the lock is not acquired, try to create the directory\n        os.makedirs(path, exist_ok=True)\n\n    return path"
  },
  {
    "path": "siirl/utils/extras/hdfs_io.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport shutil\nfrom loguru import logger\n\n\n_HDFS_PREFIX = \"hdfs://\"\n\n_HDFS_BIN_PATH = shutil.which(\"hdfs\")\n\n\ndef exists(path: str, **kwargs) -> bool:\n    r\"\"\"Works like os.path.exists() but supports hdfs.\n\n    Test whether a path exists. Returns False for broken symbolic links.\n\n    Args:\n        path (str): path to test\n\n    Returns:\n        bool: True if the path exists, False otherwise\n    \"\"\"\n    if _is_non_local(path):\n        return _exists(path, **kwargs)\n    return os.path.exists(path)\n\n\ndef _exists(file_path: str):\n    \"\"\"hdfs capable to check whether a file_path is exists\"\"\"\n    if file_path.startswith(\"hdfs\"):\n        return _run_cmd(_hdfs_cmd(f\"-test -e {file_path}\")) == 0\n    return os.path.exists(file_path)\n\n\ndef makedirs(name, mode=0o777, exist_ok=False, **kwargs) -> None:\n    r\"\"\"Works like os.makedirs() but supports hdfs.\n\n    Super-mkdir; create a leaf directory and all intermediate ones.  Works like\n    mkdir, except that any intermediate path segment (not just the rightmost)\n    will be created if it does not exist. If the target directory already\n    exists, raise an OSError if exist_ok is False. Otherwise no exception is\n    raised.  This is recursive.\n\n    Args:\n        name (str): directory to create\n        mode (int): file mode bits\n        exist_ok (bool): if True, do not raise an exception if the directory already exists\n        kwargs: keyword arguments for hdfs\n\n    \"\"\"\n    if _is_non_local(name):\n        # TODO(haibin.lin):\n        # - handle OSError for hdfs(?)\n        # - support exist_ok for hdfs(?)\n        _mkdir(name, **kwargs)\n    else:\n        os.makedirs(name, mode=mode, exist_ok=exist_ok)\n\n\ndef _mkdir(file_path: str) -> bool:\n    \"\"\"hdfs mkdir\"\"\"\n    if file_path.startswith(\"hdfs\"):\n        _run_cmd(_hdfs_cmd(f\"-mkdir -p {file_path}\"))\n    else:\n        os.makedirs(file_path, exist_ok=True)\n    return True\n\n\ndef copy(src: str, dst: str, **kwargs) -> bool:\n    r\"\"\"Works like shutil.copy() for file, and shutil.copytree for dir, and supports hdfs.\n\n    Copy data and mode bits (\"cp src dst\"). 
Return the file's destination.\n    The destination may be a directory.\n    If source and destination are the same file, a SameFileError will be\n    raised.\n\n    Arg:\n        src (str): source file path\n        dst (str): destination file path\n        kwargs: keyword arguments for hdfs copy\n\n    Returns:\n        str: destination file path\n\n    \"\"\"\n    if _is_non_local(src) or _is_non_local(dst):\n        # TODO(haibin.lin):\n        # - handle SameFileError for hdfs files(?)\n        # - return file destination for hdfs files\n        return _copy(src, dst)\n    else:\n        if os.path.isdir(src):\n            return shutil.copytree(src, dst, **kwargs)\n        else:\n            return shutil.copy(src, dst, **kwargs)\n\n\ndef _copy(from_path: str, to_path: str, timeout: int = None) -> bool:\n    if to_path.startswith(\"hdfs\"):\n        if from_path.startswith(\"hdfs\"):\n            returncode = _run_cmd(_hdfs_cmd(f\"-cp -f {from_path} {to_path}\"), timeout=timeout)\n        else:\n            returncode = _run_cmd(_hdfs_cmd(f\"-put -f {from_path} {to_path}\"), timeout=timeout)\n    else:\n        if from_path.startswith(\"hdfs\"):\n            returncode = _run_cmd(\n                _hdfs_cmd(\n                    f\"-get \\\n                {from_path} {to_path}\"\n                ),\n                timeout=timeout,\n            )\n        else:\n            try:\n                shutil.copy(from_path, to_path)\n                returncode = 0\n            except shutil.SameFileError:\n                returncode = 0\n            except Exception as e:\n                logger.warning(f\"copy {from_path} {to_path} failed: {e}\")\n                returncode = -1\n    return returncode == 0\n\n\ndef _run_cmd(cmd: str, timeout=None):\n    return os.system(cmd)\n\n\ndef _hdfs_cmd(cmd: str) -> str:\n    return f\"{_HDFS_BIN_PATH} dfs {cmd}\"\n\n\ndef _is_non_local(path: str):\n    return path.startswith(_HDFS_PREFIX)\n"
  },
  {
    "path": "siirl/utils/extras/import_utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nUtilities to check if packages are available.\nWe assume package availability won't change during runtime.\n\"\"\"\n\nimport importlib.util\nfrom functools import cache\nfrom typing import List, Optional\n\n\n@cache\ndef is_megatron_core_available():\n    try:\n        mcore_spec = importlib.util.find_spec(\"megatron.core\")\n    except ModuleNotFoundError:\n        mcore_spec = None\n    return mcore_spec is not None\n\n\n@cache\ndef is_vllm_available():\n    try:\n        vllm_spec = importlib.util.find_spec(\"vllm\")\n    except ModuleNotFoundError:\n        vllm_spec = None\n    return vllm_spec is not None\n\n\n@cache\ndef is_sglang_available():\n    try:\n        sglang_spec = importlib.util.find_spec(\"sglang\")\n    except ModuleNotFoundError:\n        sglang_spec = None\n    return sglang_spec is not None\n\n\n@cache\ndef is_nvtx_available():\n    try:\n        nvtx_spec = importlib.util.find_spec(\"nvtx\")\n    except ModuleNotFoundError:\n        nvtx_spec = None\n    return nvtx_spec is not None\n\n\ndef import_external_libs(external_libs=None):\n    if external_libs is None:\n        return\n    if not isinstance(external_libs, List):\n        external_libs = [external_libs]\n    import importlib\n\n    for external_lib in external_libs:\n        importlib.import_module(external_lib)\n\n\ndef load_extern_type(file_path: Optional[str], type_name: Optional[str]):\n    \"\"\"Load a external data type based on the file path and type name\"\"\"\n    import importlib.util\n    import os\n\n    if not file_path:\n        return None\n\n    if not os.path.exists(file_path):\n        raise FileNotFoundError(f\"Custom type file '{file_path}' not found.\")\n\n    spec = importlib.util.spec_from_file_location(\"custom_module\", file_path)\n    module = importlib.util.module_from_spec(spec)\n    try:\n        spec.loader.exec_module(module)\n    except Exception as e:\n        raise RuntimeError(f\"Error loading module from '{file_path}'\") from e\n\n    if not hasattr(module, type_name):\n        raise AttributeError(f\"Custom type '{type_name}' not found in '{file_path}'.\")\n\n    return getattr(module, type_name)\n\n\ndef _get_qualified_name(func):\n    \"\"\"Get full qualified name including module and class (if any).\"\"\"\n    module = func.__module__\n    qualname = func.__qualname__\n    return f\"{module}.{qualname}\"\n\n\ndef deprecated(replacement: str = \"\"):\n    \"\"\"Decorator to mark APIs as deprecated.\"\"\"\n    import functools\n    import warnings\n\n    def decorator(func):\n        qualified_name = _get_qualified_name(func)\n\n        @functools.wraps(func)\n        def wrapped(*args, **kwargs):\n            msg = f\"Warning: API '{qualified_name}' is deprecated.\"\n            if replacement:\n                msg += f\" Please use '{replacement}' instead.\"\n            warnings.warn(msg, category=DeprecationWarning, stacklevel=2)\n     
       return func(*args, **kwargs)\n\n        return wrapped\n\n    return decorator\n"
  },
  {
    "path": "siirl/utils/extras/misc.py",
    "content": "import gc\nimport os\nfrom typing import TYPE_CHECKING, Any, Dict, Literal, Sequence, Tuple, Union\n\nimport torch\nimport torch.distributed as dist\nimport transformers.dynamic_module_utils\nfrom transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList\nfrom transformers.dynamic_module_utils import get_relative_imports\nfrom transformers.utils import (\n    is_torch_bf16_gpu_available,\n    is_torch_cuda_available,\n    is_torch_mps_available,\n    is_torch_npu_available,\n    is_torch_xpu_available,\n)\nfrom transformers.utils.versions import require_version\n\nfrom loguru import logger\nfrom siirl.utils.extras.packages import is_transformers_version_greater_than\n\n\n_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()\ntry:\n    _is_bf16_available = is_torch_bf16_gpu_available() or (is_torch_npu_available() and torch.npu.is_bf16_supported())\nexcept Exception:\n    _is_bf16_available = False\n\nif TYPE_CHECKING:\n    from numpy.typing import NDArray\n\n    from siirl.params import ModelArguments\n\n\nclass AverageMeter:\n    r\"\"\"\n    Computes and stores the average and current value.\n    \"\"\"\n\n    def __init__(self):\n        self.reset()\n\n    def reset(self):\n        self.val = 0\n        self.avg = 0\n        self.sum = 0\n        self.count = 0\n\n    def update(self, val, n=1):\n        self.val = val\n        self.sum += val * n\n        self.count += n\n        self.avg = self.sum / self.count\n\n\ndef check_version(requirement: str, mandatory: bool = False) -> None:\n    r\"\"\"\n    Optionally checks the package version.\n    \"\"\"\n    if is_env_enabled(\"DISABLE_VERSION_CHECK\") and not mandatory:\n        logger.warning(\"Version checking has been disabled, may lead to unexpected behaviors.\")\n        return\n\n    if mandatory:\n        hint = f\"To fix: run `pip install {requirement}`.\"\n    else:\n        hint = f\"To fix: run `pip install {requirement}` or set `DISABLE_VERSION_CHECK=1` to skip this check.\"\n\n    require_version(requirement, hint)\n\n\ndef check_dependencies() -> None:\n    r\"\"\"\n    Checks the version of the required packages.\n    \"\"\"\n    check_version(\"transformers>=4.41.2,<=4.49.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0\")\n    check_version(\"datasets>=2.16.0,<=3.2.0\")\n    check_version(\"accelerate>=0.34.0,<=1.2.1\")\n    check_version(\"peft>=0.11.1,<=0.12.0\")\n    check_version(\"trl>=0.8.6,<=0.9.6\")\n    if is_transformers_version_greater_than(\"4.46.0\") and not is_transformers_version_greater_than(\"4.48.1\"):\n        logger.warning(\"There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.\")\n\n\ndef calculate_tps(\n    dataset: Sequence[Dict[str, Any]],\n    metrics: Dict[str, float],\n    stage: Literal[\"sft\", \"rm\"],\n) -> float:\n    r\"\"\"\n    Calculates effective tokens per second.\n    \"\"\"\n    effective_token_num = 0\n    for data in dataset:\n        if stage == \"sft\":\n            effective_token_num += len(data[\"input_ids\"])\n        elif stage == \"rm\":\n            effective_token_num += len(data[\"chosen_input_ids\"]) + len(data[\"rejected_input_ids\"])\n\n    result = effective_token_num * metrics[\"epoch\"] / metrics[\"train_runtime\"]\n    return result / dist.get_world_size() if dist.is_initialized() else result\n\n\ndef count_parameters(model: \"torch.nn.Module\") -> Tuple[int, int]:\n    r\"\"\"\n    Returns the number of trainable parameters and number of all parameters in the 
model.\n    \"\"\"\n    trainable_params, all_param = 0, 0\n    for param in model.parameters():\n        num_params = param.numel()\n        # if using DS Zero 3 and the weights are initialized empty\n        if num_params == 0 and hasattr(param, \"ds_numel\"):\n            num_params = param.ds_numel\n\n        # Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by itemsize\n        if param.__class__.__name__ == \"Params4bit\":\n            if hasattr(param, \"quant_storage\") and hasattr(param.quant_storage, \"itemsize\"):\n                num_bytes = param.quant_storage.itemsize\n            elif hasattr(param, \"element_size\"):  # for older pytorch version\n                num_bytes = param.element_size()\n            else:\n                num_bytes = 1\n\n            num_params = num_params * 2 * num_bytes\n\n        all_param += num_params\n        if param.requires_grad:\n            trainable_params += num_params\n\n    return trainable_params, all_param\n\n\ndef get_current_device() -> \"torch.device\":\n    r\"\"\"\n    Gets the current available device.\n    \"\"\"\n    if is_torch_xpu_available():\n        device = \"xpu:{}\".format(os.environ.get(\"LOCAL_RANK\", \"0\"))\n    elif is_torch_npu_available():\n        device = \"npu:{}\".format(os.environ.get(\"LOCAL_RANK\", \"0\"))\n    elif is_torch_mps_available():\n        device = \"mps:{}\".format(os.environ.get(\"LOCAL_RANK\", \"0\"))\n    elif is_torch_cuda_available():\n        device = \"cuda:{}\".format(os.environ.get(\"LOCAL_RANK\", \"0\"))\n    else:\n        device = \"cpu\"\n\n    return torch.device(device)\n\n\ndef get_device_count() -> int:\n    r\"\"\"\n    Gets the number of available GPU or NPU devices.\n    \"\"\"\n    if is_torch_xpu_available():\n        return torch.xpu.device_count()\n    elif is_torch_npu_available():\n        return torch.npu.device_count()\n    elif is_torch_cuda_available():\n        return torch.cuda.device_count()\n    else:\n        return 0\n\n\ndef get_logits_processor() -> \"LogitsProcessorList\":\n    r\"\"\"\n    Gets logits processor that removes NaN and Inf logits.\n    \"\"\"\n    logits_processor = LogitsProcessorList()\n    logits_processor.append(InfNanRemoveLogitsProcessor())\n    return logits_processor\n\n\ndef get_peak_memory() -> Tuple[int, int]:\n    r\"\"\"\n    Gets the peak memory usage for the current device (in Bytes).\n    \"\"\"\n    if is_torch_npu_available():\n        return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()\n    elif is_torch_cuda_available():\n        return torch.cuda.max_memory_allocated(), torch.cuda.max_memory_reserved()\n    else:\n        return 0, 0\n\n\ndef has_tokenized_data(path: \"os.PathLike\") -> bool:\n    r\"\"\"\n    Checks if the path has a tokenized dataset.\n    \"\"\"\n    return os.path.isdir(path) and len(os.listdir(path)) > 0\n\n\ndef infer_optim_dtype(model_dtype: \"torch.dtype\") -> \"torch.dtype\":\n    r\"\"\"\n    Infers the optimal dtype according to the model_dtype and device compatibility.\n    \"\"\"\n    if _is_bf16_available and model_dtype == torch.bfloat16:\n        return torch.bfloat16\n    elif _is_fp16_available:\n        return torch.float16\n    else:\n        return torch.float32\n\n\ndef is_gpu_or_npu_available() -> bool:\n    r\"\"\"\n    Checks if the GPU or NPU is available.\n    \"\"\"\n    return is_torch_npu_available() or is_torch_cuda_available()\n\n\ndef is_env_enabled(env_var: str, default: str = \"0\") -> bool:\n    
r\"\"\"\n    Checks if the environment variable is enabled.\n    \"\"\"\n    return os.getenv(env_var, default).lower() in [\"true\", \"y\", \"1\"]\n\n\ndef numpify(inputs: Union[\"NDArray\", \"torch.Tensor\"]) -> \"NDArray\":\n    r\"\"\"\n    Casts a torch tensor or a numpy array to a numpy array.\n    \"\"\"\n    if isinstance(inputs, torch.Tensor):\n        inputs = inputs.cpu()\n        if inputs.dtype == torch.bfloat16:  # numpy does not support bfloat16 until 1.21.4\n            inputs = inputs.to(torch.float32)\n\n        inputs = inputs.numpy()\n\n    return inputs\n\n\ndef skip_check_imports() -> None:\n    r\"\"\"\n    Avoids flash attention import error in custom model files.\n    \"\"\"\n    if not is_env_enabled(\"FORCE_CHECK_IMPORTS\"):\n        transformers.dynamic_module_utils.check_imports = get_relative_imports\n\n\ndef torch_gc() -> None:\n    r\"\"\"\n    Collects GPU or NPU memory.\n    \"\"\"\n    gc.collect()\n    if is_torch_xpu_available():\n        torch.xpu.empty_cache()\n    elif is_torch_npu_available():\n        torch.npu.empty_cache()\n    elif is_torch_mps_available():\n        torch.mps.empty_cache()\n    elif is_torch_cuda_available():\n        torch.cuda.empty_cache()\n\n\ndef try_download_model_from_other_hub(model_args: \"ModelArguments\") -> str:\n    if (not use_modelscope() and not use_openmind()) or os.path.exists(model_args.path):\n        return model_args.path\n\n    if use_modelscope():\n        check_version(\"modelscope>=1.11.0\", mandatory=True)\n        from modelscope import snapshot_download  # type: ignore\n\n        revision = \"master\" if model_args.model_revision == \"main\" else model_args.model_revision\n        return snapshot_download(\n            model_args.path,\n            revision=revision,\n            cache_dir=model_args.cache_dir,\n        )\n\n    if use_openmind():\n        check_version(\"openmind>=0.8.0\", mandatory=True)\n        from openmind.utils.hub import snapshot_download  # type: ignore\n\n        return snapshot_download(\n            model_args.path,\n            revision=model_args.model_revision,\n            cache_dir=model_args.cache_dir,\n        )\n\n\ndef use_modelscope() -> bool:\n    return is_env_enabled(\"USE_MODELSCOPE_HUB\")\n\n\ndef use_openmind() -> bool:\n    return is_env_enabled(\"USE_OPENMIND_HUB\")\n\n\ndef use_ray() -> bool:\n    return is_env_enabled(\"USE_RAY\")\n"
  },
  {
    "path": "siirl/utils/extras/net_utils.py",
    "content": "# Copyright 2023-2024 SGLang Team\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport ipaddress\n\n\ndef is_ipv4(ip_str: str) -> bool:\n    \"\"\"\n    Check if the given string is an IPv4 address\n\n    Args:\n        ip_str: The IP address string to check\n\n    Returns:\n        bool: Returns True if it's an IPv4 address, False otherwise\n    \"\"\"\n    try:\n        ipaddress.IPv4Address(ip_str)\n        return True\n    except ipaddress.AddressValueError:\n        return False\n\n\ndef is_ipv6(ip_str: str) -> bool:\n    \"\"\"\n    Check if the given string is an IPv6 address\n\n    Args:\n        ip_str: The IP address string to check\n\n    Returns:\n        bool: Returns True if it's an IPv6 address, False otherwise\n    \"\"\"\n    try:\n        ipaddress.IPv6Address(ip_str)\n        return True\n    except ipaddress.AddressValueError:\n        return False\n"
  },
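  {
    "path": "examples/sketches/net_utils_demo.py",
    "content": "# NOTE: hypothetical usage sketch added for illustration only; this file is not part of the original source tree.\n# It demonstrates the expected behavior of the address helpers in siirl/utils/extras/net_utils.py.\nfrom siirl.utils.extras.net_utils import is_ipv4, is_ipv6\n\nif __name__ == \"__main__\":\n    # Literal IPv4 strings are accepted only by is_ipv4, literal IPv6 strings only by is_ipv6.\n    assert is_ipv4(\"192.168.0.1\") and not is_ipv4(\"::1\")\n    assert is_ipv6(\"::1\") and not is_ipv6(\"192.168.0.1\")\n    # Hostnames and other non-literal strings are rejected by both helpers.\n    assert not is_ipv4(\"localhost\") and not is_ipv6(\"localhost\")\n    print(\"net_utils sketch passed\")\n"
  },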
  {
    "path": "siirl/utils/extras/packages.py",
    "content": "import importlib.metadata\nimport importlib.util\nfrom functools import lru_cache\nfrom typing import TYPE_CHECKING\n\nfrom packaging import version\n\nif TYPE_CHECKING:\n    from packaging.version import Version\n\n\ndef _is_package_available(name: str) -> bool:\n    return importlib.util.find_spec(name) is not None\n\n\ndef _get_package_version(name: str) -> \"Version\":\n    try:\n        return version.parse(importlib.metadata.version(name))\n    except Exception:\n        return version.parse(\"0.0.0\")\n\n\ndef is_pyav_available():\n    return _is_package_available(\"av\")\n\n\ndef is_librosa_available():\n    return _is_package_available(\"librosa\")\n\n\ndef is_fastapi_available():\n    return _is_package_available(\"fastapi\")\n\n\ndef is_galore_available():\n    return _is_package_available(\"galore_torch\")\n\n\ndef is_apollo_available():\n    return _is_package_available(\"apollo_torch\")\n\n\ndef is_gradio_available():\n    return _is_package_available(\"gradio\")\n\n\ndef is_matplotlib_available():\n    return _is_package_available(\"matplotlib\")\n\n\ndef is_pillow_available():\n    return _is_package_available(\"PIL\")\n\n\ndef is_ray_available():\n    return _is_package_available(\"ray\")\n\n\ndef is_requests_available():\n    return _is_package_available(\"requests\")\n\n\ndef is_rouge_available():\n    return _is_package_available(\"rouge_chinese\")\n\n\ndef is_starlette_available():\n    return _is_package_available(\"sse_starlette\")\n\n\n@lru_cache\ndef is_transformers_version_greater_than(content: str):\n    return _get_package_version(\"transformers\") >= version.parse(content)\n\n\ndef is_uvicorn_available():\n    return _is_package_available(\"uvicorn\")\n\n\ndef is_vllm_available():\n    return _is_package_available(\"vllm\")\n"
  },
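  {
    "path": "examples/sketches/packages_demo.py",
    "content": "# NOTE: hypothetical usage sketch added for illustration only; this file is not part of the original source tree.\n# It shows how the availability helpers in siirl/utils/extras/packages.py can feature-gate optional dependencies.\nfrom siirl.utils.extras.packages import (\n    is_ray_available,\n    is_transformers_version_greater_than,\n    is_vllm_available,\n)\n\nif __name__ == \"__main__\":\n    if is_vllm_available():\n        print(\"vllm is importable\")\n    if is_ray_available():\n        print(\"ray is importable\")\n    # The transformers version check is wrapped in lru_cache, so repeated calls are cheap.\n    print(\"transformers >= 4.46.0:\", is_transformers_version_greater_than(\"4.46.0\"))\n"
  },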
  {
    "path": "siirl/utils/extras/patch.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom loguru import logger\nfrom itertools import product\nfrom math_verify.errors import TimeoutException\nfrom math_verify.grader import sympy_expr_eq\n\nfrom sympy import Basic, MatrixBase\nfrom math_verify.utils import timeout\n\n\ndef verify(\n    gold: list[Basic | MatrixBase | str] | Basic | MatrixBase | str,\n    target: list[Basic | MatrixBase | str] | Basic | MatrixBase | str,\n    float_rounding: int = 6,\n    numeric_precision: int = 15,\n    strict: bool = True,\n    timeout_seconds: int = 5,\n) -> bool:\n    \"\"\"Verifies if the target expression matches the gold expression using multiple comparison strategies.\n\n    This function implements a comprehensive comparison system for mathematical expressions,\n    handling various types of mathematical objects (numbers, expressions, sets, matrices, etc.)\n    with multiple fallback strategies.\n\n    Note:\n        - It's expected that both gold and pred has been parsed with math_verify.parse function.\n        - Function is not symmetric, gold answer should be passed as gold and prediction as pred. The non-symmetric nature appears at assignment simplification and equation interval conversion.\n\n    Args:\n        gold: The reference/correct expression(s). Can be:\n            - A single SymPy expression (Basic or MatrixBase)\n            - A string\n            - A list of any of the above\n        target: The expression(s) to verify. Same types as gold.\n        float_rounding: Number of decimal places to round floats to. Defaults to 6.\n        numeric_precision: Number of decimal places to consider for numeric comparisons. Defaults to 15.\n            - If you know the evaluated expressions will be small, you should increase this. See: https://docs.sympy.org/latest/modules/evalf.html\n        strict: Whether to enforce strict comparison mode. Defaults to True.\n            - In strict mode: Variables matter and sets are not comparable with tuples\n            - In non-strict mode: Variables are matched by position and sets can be compared with tuples\n        timeout_seconds: Maximum time in seconds to spend on any single comparison operation.\n            Defaults to 5 seconds.\n\n    Returns:\n        bool: True if target matches gold according to any of the comparison strategies,\n              False otherwise.\n\n    Comparison Strategy:\n        1. String to String comparison\n        2. Numeric expressions: Comparison within specified precision\n        3. Symbolic equality through simplification\n        4. Special handling for:\n            - Relational expressions (equations/inequalities)\n            - Sets and intervals\n            - Matrices and vectors\n            - Complex numbers\n        5. 
Robust error handling with timeout protection\n\n    Example:\n        >>> verify(sympy.Rational(1, 3), 0.333333)  # Numeric comparison\n        True\n        >>> verify(sympy.Symbol('x') + 1, sympy.Symbol('y') + 1, strict=False)  # Variable matching\n        True\n        >>> verify(sympy.FiniteSet(1, 2), sympy.Tuple(1, 2), strict=False)  # Set-tuple comparison\n        True\n    \"\"\"\n\n    @timeout(timeout_seconds=timeout_seconds)\n    def compare_single_extraction(gold: Basic | MatrixBase | str, target: Basic | MatrixBase | str) -> bool:\n        # If both are sympy expressions, we can use sympy to compare them\n        if isinstance(gold, (Basic, MatrixBase)) and isinstance(target, (Basic, MatrixBase)):\n            return sympy_expr_eq(gold, target, float_rounding, numeric_precision, strict)\n\n        # We don't support str / sympy.Expr comparison; there is little point in doing so, as the chances\n        # of it happening are very low. The only reason one of them would not be converted to a sympy\n        # expression is that the parsing logic failed; in that case we should improve the parsing logic\n        # instead of patching it ad hoc.\n        elif isinstance(gold, str) and isinstance(target, str):\n            # We just do string comparison for everything else\n            gold = gold.strip()\n            target = target.strip()\n\n            # Ensure both are non-empty and equal\n            return len(gold) > 0 and len(target) > 0 and gold == target\n\n        return False\n\n    def compare_single_extraction_wrapper(g, t):\n        try:\n            return compare_single_extraction(g, t)\n        except TimeoutException:\n            # Catch the timeout first so the broader Exception handler below cannot shadow it.\n            # logger.error(\"Timeout during comparison\")\n            return False\n        except Exception:\n            #! Do not attempt to print out g and t while handling the exception:\n            # a) doing so can itself raise, and b) the str conversion can hang forever.\n            # logger.exception(\"Error during comparison\")\n            return False\n\n    if not isinstance(gold, list):\n        gold = [gold]\n    if not isinstance(target, list):\n        target = [target]\n\n    return any(compare_single_extraction_wrapper(g, t) for g, t in product(gold, target))\n"
  },
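  {
    "path": "examples/sketches/math_verify_patch_demo.py",
    "content": "# NOTE: hypothetical usage sketch added for illustration only; this file is not part of the original source tree.\n# It exercises the string-comparison path of the patched verify() in siirl/utils/extras/patch.py, which is\n# fully determined by the code above; parsed sympy expressions instead go through sympy_expr_eq.\nfrom siirl.utils.extras.patch import verify\n\nif __name__ == \"__main__\":\n    assert verify(\"42\", \"42\")  # stripped, non-empty strings must match exactly\n    assert not verify(\"42\", \"43\")  # mismatching strings fail\n    assert not verify(\"\", \"\")  # empty strings never match\n    assert verify([\"1/2\", \"0.5\"], \"0.5\")  # any (gold, target) pair matching is enough\n    print(\"verify sketch passed\")\n"
  },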
  {
    "path": "siirl/utils/extras/py_functional.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nContain small python utility functions\n\"\"\"\n\nimport importlib\nimport multiprocessing\nimport os\nimport queue  # Import the queue module for exception type hint\nimport signal\nfrom functools import wraps\nfrom types import SimpleNamespace\nfrom typing import Any, Callable, Dict, Iterator, Optional, Tuple\n\n\n# --- Top-level helper for multiprocessing timeout ---\n# This function MUST be defined at the top level to be pickleable\ndef _mp_target_wrapper(target_func: Callable, mp_queue: multiprocessing.Queue, args: Tuple, kwargs: Dict[str, Any]):\n    \"\"\"\n    Internal wrapper function executed in the child process.\n    Calls the original target function and puts the result or exception into the queue.\n    \"\"\"\n    try:\n        result = target_func(*args, **kwargs)\n        mp_queue.put((True, result))  # Indicate success and put result\n    except Exception as e:\n        # Ensure the exception is pickleable for the queue\n        try:\n            import pickle\n\n            pickle.dumps(e)  # Test if the exception is pickleable\n            mp_queue.put((False, e))  # Indicate failure and put exception\n        except (pickle.PicklingError, TypeError):\n            # Fallback if the original exception cannot be pickled\n            mp_queue.put((False, RuntimeError(f\"Original exception type {type(e).__name__} not pickleable: {e}\")))\n\n\n# Renamed the function from timeout to timeout_limit\ndef timeout_limit(seconds: float, use_signals: bool = False):\n    \"\"\"\n    Decorator to add a timeout to a function.\n\n    Args:\n        seconds: The timeout duration in seconds.\n        use_signals: (Deprecated)  This is deprecated because signals only work reliably in the main thread\n                     and can cause issues in multiprocessing or multithreading contexts.\n                     Defaults to False, which uses the more robust multiprocessing approach.\n\n    Returns:\n        A decorated function with timeout.\n\n    Raises:\n        TimeoutError: If the function execution exceeds the specified time.\n        RuntimeError: If the child process exits with an error (multiprocessing mode).\n        NotImplementedError: If the OS is not POSIX (signals are only supported on POSIX).\n    \"\"\"\n\n    def decorator(func):\n        if use_signals:\n            if os.name != \"posix\":\n                raise NotImplementedError(f\"Unsupported OS: {os.name}\")\n            # Issue deprecation warning if use_signals is explicitly True\n            print(\n                \"WARN: The 'use_signals=True' option in the timeout decorator is deprecated. \\\n                Signals are unreliable outside the main thread. 
\\\n                Please use the default multiprocessing-based timeout (use_signals=False).\"\n            )\n\n            @wraps(func)\n            def wrapper_signal(*args, **kwargs):\n                def handler(signum, frame):\n                    # Update function name in error message if needed (optional but good practice)\n                    raise TimeoutError(f\"Function {func.__name__} timed out after {seconds} seconds (signal)!\")\n\n                old_handler = signal.getsignal(signal.SIGALRM)\n                signal.signal(signal.SIGALRM, handler)\n                # Use setitimer for float seconds support, alarm only supports integers\n                signal.setitimer(signal.ITIMER_REAL, seconds)\n\n                try:\n                    result = func(*args, **kwargs)\n                finally:\n                    # Reset timer and handler\n                    signal.setitimer(signal.ITIMER_REAL, 0)\n                    signal.signal(signal.SIGALRM, old_handler)\n                return result\n\n            return wrapper_signal\n        else:\n            # --- Multiprocessing based timeout (existing logic) ---\n            @wraps(func)\n            def wrapper_mp(*args, **kwargs):\n                q = multiprocessing.Queue(maxsize=1)\n                process = multiprocessing.Process(target=_mp_target_wrapper, args=(func, q, args, kwargs))\n                process.start()\n                process.join(timeout=seconds)\n\n                if process.is_alive():\n                    process.terminate()\n                    process.join(timeout=0.5)  # Give it a moment to terminate\n                    if process.is_alive():\n                        print(f\"Warning: Process {process.pid} did not terminate gracefully after timeout.\")\n                    # Update function name in error message if needed (optional but good practice)\n                    raise TimeoutError(f\"Function {func.__name__} timed out after {seconds} seconds (multiprocessing)!\")\n\n                try:\n                    success, result_or_exc = q.get(timeout=0.1)  # Small timeout for queue read\n                    if success:\n                        return result_or_exc\n                    else:\n                        raise result_or_exc  # Reraise exception from child\n                except queue.Empty as err:\n                    exitcode = process.exitcode\n                    if exitcode is not None and exitcode != 0:\n                        raise RuntimeError(f\"Child process exited with error (exitcode: {exitcode}) before returning result.\") from err\n                    else:\n                        # Should have timed out if queue is empty after join unless process died unexpectedly\n                        # Update function name in error message if needed (optional but good practice)\n                        raise TimeoutError(f\"Operation timed out or process finished unexpectedly without result (exitcode: {exitcode}).\") from err\n                finally:\n                    q.close()\n                    q.join_thread()\n\n            return wrapper_mp\n\n    return decorator\n\n\ndef union_two_dict(dict1: Dict, dict2: Dict):\n    \"\"\"Union two dict. 
Will throw an error if there is an item not the same object with the same key.\n\n    Args:\n        dict1:\n        dict2:\n\n    Returns:\n\n    \"\"\"\n    for key, val in dict2.items():\n        if key in dict1:\n            assert dict2[key] == dict1[key], f\"{key} in meta_dict1 and meta_dict2 are not the same object\"\n        dict1[key] = val\n\n    return dict1\n\n\ndef append_to_dict(data: Dict, new_data: Dict):\n    \"\"\"Append values from new_data to lists in data.\n\n    For each key in new_data, this function appends the corresponding value to a list\n    stored under the same key in data. If the key doesn't exist in data, a new list is created.\n\n    Args:\n        data (Dict): The target dictionary containing lists as values.\n        new_data (Dict): The source dictionary with values to append.\n\n    Returns:\n        None: The function modifies data in-place.\n    \"\"\"\n    for key, val in new_data.items():\n        if key not in data:\n            data[key] = []\n        data[key].append(val)\n\n\nclass NestedNamespace(SimpleNamespace):\n    \"\"\"A nested version of SimpleNamespace that recursively converts dictionaries to namespaces.\n\n    This class allows for dot notation access to nested dictionary structures by recursively\n    converting dictionaries to NestedNamespace objects.\n\n    Example:\n        config_dict = {\"a\": 1, \"b\": {\"c\": 2, \"d\": 3}}\n        config = NestedNamespace(config_dict)\n        # Access with: config.a, config.b.c, config.b.d\n\n    Args:\n        dictionary: The dictionary to convert to a nested namespace.\n        **kwargs: Additional attributes to set on the namespace.\n    \"\"\"\n\n    def __init__(self, dictionary, **kwargs):\n        super().__init__(**kwargs)\n        for key, value in dictionary.items():\n            if isinstance(value, dict):\n                self.__setattr__(key, NestedNamespace(value))\n            else:\n                self.__setattr__(key, value)\n\n\nclass DynamicEnumMeta(type):\n    def __iter__(cls) -> Iterator[Any]:\n        return iter(cls._registry.values())\n\n    def __contains__(cls, item: Any) -> bool:\n        # allow `name in EnumClass` or `member in EnumClass`\n        if isinstance(item, str):\n            return item in cls._registry\n        return item in cls._registry.values()\n\n    def __getitem__(cls, name: str) -> Any:\n        return cls._registry[name]\n\n    def __reduce_ex__(cls, protocol):\n        # Always load the existing module and grab the class\n        return getattr, (importlib.import_module(cls.__module__), cls.__name__)\n\n    def names(cls):\n        return list(cls._registry.keys())\n\n    def values(cls):\n        return list(cls._registry.values())\n\n\nclass DynamicEnum(metaclass=DynamicEnumMeta):\n    _registry: Dict[str, \"DynamicEnum\"] = {}\n    _next_value: int = 0\n\n    def __init__(self, name: str, value: int):\n        self.name = name\n        self.value = value\n\n    def __repr__(self):\n        return f\"<{self.__class__.__name__}.{self.name}: {self.value}>\"\n\n    def __reduce_ex__(self, protocol):\n        \"\"\"\n        Unpickle via: getattr(import_module(module).Dispatch, 'ONE_TO_ALL')\n        so the existing class is reused instead of re-executed.\n        \"\"\"\n        module = importlib.import_module(self.__class__.__module__)\n        enum_cls = getattr(module, self.__class__.__name__)\n        return getattr, (enum_cls, self.name)\n\n    @classmethod\n    def register(cls, name: str) -> \"DynamicEnum\":\n        key = 
name.upper()\n        if key in cls._registry:\n            raise ValueError(f\"{key} already registered\")\n        member = cls(key, cls._next_value)\n        cls._registry[key] = member\n        setattr(cls, key, member)\n        cls._next_value += 1\n        return member\n\n    @classmethod\n    def remove(cls, name: str):\n        key = name.upper()\n        member = cls._registry.pop(key)\n        delattr(cls, key)\n        return member\n\n    @classmethod\n    def from_name(cls, name: str) -> Optional[\"DynamicEnum\"]:\n        return cls._registry.get(name.upper())\n\n\ndef convert_to_regular_types(obj):\n    \"\"\"Convert Hydra configs and other special types to regular Python types.\"\"\"\n    from omegaconf import DictConfig, ListConfig\n\n    if isinstance(obj, (ListConfig, DictConfig)):\n        return {k: convert_to_regular_types(v) for k, v in obj.items()} if isinstance(obj, DictConfig) else list(obj)\n    elif isinstance(obj, (list, tuple)):\n        return [convert_to_regular_types(x) for x in obj]\n    elif isinstance(obj, dict):\n        return {k: convert_to_regular_types(v) for k, v in obj.items()}\n    return obj\n"
  },
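  {
    "path": "examples/sketches/py_functional_demo.py",
    "content": "# NOTE: hypothetical usage sketch added for illustration only; this file is not part of the original source tree.\n# It exercises the multiprocessing-based timeout decorator and two small helpers from\n# siirl/utils/extras/py_functional.py. The decorator is applied explicitly (instead of with @) so the\n# undecorated functions stay importable by name, which keeps them picklable for the child process even\n# under the \"spawn\" start method.\nimport time\n\nfrom siirl.utils.extras.py_functional import NestedNamespace, append_to_dict, timeout_limit\n\n\ndef quick_task() -> str:\n    return \"done\"\n\n\ndef slow_task() -> str:\n    time.sleep(5)\n    return \"never reached\"\n\n\nif __name__ == \"__main__\":\n    guarded_quick = timeout_limit(seconds=2.0)(quick_task)\n    guarded_slow = timeout_limit(seconds=0.5)(slow_task)\n\n    assert guarded_quick() == \"done\"\n    try:\n        guarded_slow()\n    except TimeoutError as exc:\n        print(f\"slow_task timed out as expected: {exc}\")\n\n    # NestedNamespace gives dot access to nested dict configs.\n    config = NestedNamespace({\"trainer\": {\"lr\": 1e-4, \"epochs\": 3}})\n    print(config.trainer.lr, config.trainer.epochs)\n\n    # append_to_dict accumulates per-key lists in place.\n    metrics = {}\n    append_to_dict(metrics, {\"loss\": 0.5})\n    append_to_dict(metrics, {\"loss\": 0.25})\n    print(metrics)  # {'loss': [0.5, 0.25]}\n"
  },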
  {
    "path": "siirl/utils/extras/ray_utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nContains commonly used utilities for ray\n\"\"\"\n\nimport concurrent.futures\nimport os\nfrom typing import Any, List, Optional\n\nimport ray\n\n\ndef ray_noset_visible_devices(env_vars=os.environ):\n    # Refer to\n    # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/nvidia_gpu.py#L95-L96\n    # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/amd_gpu.py#L102-L103\n    # https://github.com/ray-project/ray/blob/3b9e729f6a669ffd85190f901f5e262af79771b0/python/ray/_private/accelerators/amd_gpu.py#L114-L115\n    # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/npu.py#L94-L95\n    # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/hpu.py#L116-L117\n    # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/neuron.py#L108-L109\n    # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/tpu.py#L171-L172\n    # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/intel_gpu.py#L97-L98\n    NOSET_VISIBLE_DEVICES_ENV_VARS_LIST = [\n        \"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES\",\n        \"RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES\",\n        \"RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES\",\n        \"RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES\",\n        \"RAY_EXPERIMENTAL_NOSET_HABANA_VISIBLE_MODULES\",\n        \"RAY_EXPERIMENTAL_NOSET_NEURON_RT_VISIBLE_CORES\",\n        \"RAY_EXPERIMENTAL_NOSET_TPU_VISIBLE_CHIPS\",\n        \"RAY_EXPERIMENTAL_NOSET_ONEAPI_DEVICE_SELECTOR\",\n    ]\n    return any(env_vars.get(env_var) for env_var in NOSET_VISIBLE_DEVICES_ENV_VARS_LIST)\n\n\ndef parallel_put(data_list: List[Any], max_workers: Optional[int] = None):\n    \"\"\"\n    Puts a list of data into the Ray object store in parallel using a thread pool.\n\n    Args:\n        data_list (List[Any]): A list of Python objects to be put into the Ray object store.\n        max_workers (int, optional): The maximum number of worker threads to use.\n                                     Defaults to min(len(data_list), 16).\n\n    Returns:\n        List[ray.ObjectRef]: A list of Ray object references corresponding to the input data_list,\n                             maintaining the original order.\n    \"\"\"\n    assert len(data_list) > 0, \"data_list must not be empty\"\n\n    def put_data(index, data):\n        return index, ray.put(data)\n\n    if max_workers is None:\n        max_workers = min(len(data_list), 16)\n\n    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:\n        data_list_f = 
[executor.submit(put_data, i, data) for i, data in enumerate(data_list)]\n        res_lst = []\n        for future in concurrent.futures.as_completed(data_list_f):\n            res_lst.append(future.result())\n\n        # reorder based on index\n        output = [None for _ in range(len(data_list))]\n        for res in res_lst:\n            index, data_ref = res\n            output[index] = data_ref\n\n    return output\n"
  },
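  {
    "path": "examples/sketches/ray_utils_demo.py",
    "content": "# NOTE: hypothetical usage sketch added for illustration only; this file is not part of the original source tree.\n# It stores a small batch of objects in the Ray object store with parallel_put from\n# siirl/utils/extras/ray_utils.py and checks that the returned references preserve the input order.\nimport ray\n\nfrom siirl.utils.extras.ray_utils import parallel_put\n\nif __name__ == \"__main__\":\n    ray.init(ignore_reinit_error=True)\n\n    payloads = [{\"step\": i, \"tokens\": list(range(i))} for i in range(8)]\n    refs = parallel_put(payloads, max_workers=4)\n\n    # ray.get on the ordered references returns the original objects in the same order.\n    assert ray.get(refs) == payloads\n    print(\"parallel_put sketch passed\")\n\n    ray.shutdown()\n"
  },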
  {
    "path": "siirl/utils/import_string.py",
    "content": "# Copyright 2025, Infrawaves. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport importlib\n\n\ndef import_string(import_name: str):\n    \"\"\"Dynamically imports a module or object from a string.\"\"\"\n    module_name, obj_name = import_name.rsplit(\".\", 1)\n    try:\n        module = importlib.import_module(module_name)\n        return getattr(module, obj_name)\n    except (ImportError, AttributeError) as e:\n        raise ImportError(f\"Could not import {import_name}\") from e\n\nif __name__ == \"__main__\":\n    print(import_string(\"siirl.engine.sharding_manager.fsdp_vllm.MultiAgentFSDPVLLMShardingManager\"))"
  },
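  {
    "path": "examples/sketches/import_string_demo.py",
    "content": "# NOTE: hypothetical usage sketch added for illustration only; this file is not part of the original source tree.\n# It resolves dotted paths to Python objects with import_string from siirl/utils/import_string.py.\nfrom siirl.utils.import_string import import_string\n\nif __name__ == \"__main__\":\n    sqrt = import_string(\"math.sqrt\")\n    print(sqrt(9.0))  # 3.0\n\n    # A missing module and a missing attribute are both reported as ImportError.\n    try:\n        import_string(\"math.does_not_exist\")\n    except ImportError as exc:\n        print(f\"lookup failed as expected: {exc}\")\n"
  },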
  {
    "path": "siirl/utils/kernel/__init__.py",
    "content": "#\n# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n\n# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n"
  },
  {
    "path": "siirl/utils/kernel/kernels.py",
    "content": "#\n# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n\n# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nImplementations of the linear cross entropy with token entropy kernel.\n\"\"\"\n\nimport typing\nfrom dataclasses import dataclass\n\nimport torch\nimport torch.distributed as dist\n\ntry:\n    import triton\n    import triton.language as tl\n\n    HAVE_TRITON = True\nexcept ImportError:\n    HAVE_TRITON = False\n\nfrom siirl.utils.extras.device import get_torch_device\n\nif not HAVE_TRITON:\n    from contextlib import contextmanager\n    from unittest.mock import MagicMock\n\n    @contextmanager\n    def null_decorator(*args, **kwargs):\n        if len(kwargs) == 0 and len(args) == 1 and callable(args[0]):\n            return args[0]\n        else:\n\n            def inner(func):\n                return func\n\n            return inner\n\n    triton = MagicMock()\n    triton.jit = null_decorator\n    triton.autotune = null_decorator\n    tl = MagicMock()\n\n\n@dataclass\nclass EntropyReductionEnum:\n    \"\"\"\n    Enum for the reduction method of cross entropy.\n    \"\"\"\n\n    _None = 0\n    _Sum = 1\n    _Mean = 2\n\n\ndef get_entropy_reduction_enum_number(reduction: str) -> int:\n    \"\"\"\n    Get the enum number for the reduction method of cross entropy.\n    \"\"\"\n    _enum = EntropyReductionEnum._None\n    if reduction == \"none\":\n        _enum = EntropyReductionEnum._None\n    elif reduction == \"sum\":\n        _enum = EntropyReductionEnum._Sum\n    elif reduction == \"mean\":\n        _enum = EntropyReductionEnum._Mean\n    else:\n        raise ValueError(f\"Invalid reduction: {reduction}\")\n    return _enum\n\n\ndef get_entropy_reduction_enum(ce_reduction: int) -> EntropyReductionEnum:\n    \"\"\"\n    Get the enum for the reduction method of cross entropy.\n    \"\"\"\n    _enum = EntropyReductionEnum._None\n    if ce_reduction == 0:\n        _enum = EntropyReductionEnum._None\n    elif ce_reduction == 1:\n        _enum = EntropyReductionEnum._Sum\n    elif ce_reduction == 2:\n        _enum = EntropyReductionEnum._Mean\n    else:\n        raise ValueError(f\"Invalid ce_reduction: {ce_reduction}\")\n    return _enum\n\n\n@dataclass\nclass BackwardEnum:\n    \"\"\"\n    Enum for the backward method.\n    \"\"\"\n\n    _Total_Fuse_MN = (\n        0  # 
Fuse d_logits & d_hidden & d_weight, no intermediate storage, requires fp32 for d_hidden & d_weight\n    )\n    _Total_Separate = 1  # Store d_logits, no special requirements for d_hidden & d_weight\n    _Split_Dlogits_N = 2  # split d_logits along its N dimension, aka. vocab_size\n    _Split_Dlogits_M = 3  # split d_logits along its M dimension, aka. num_tokens\n\n\n@dataclass\nclass Config:\n    \"\"\"Configuration for efficient entropy kernel operations.\n\n    Args:\n        _backward (BackwardEnum): Backward computation method. Defaults to BackwardEnum._Split_Dlogits_N.\n        _use_triton (bool): Whether to use Triton kernels for computation. Defaults to True.\n    \"\"\"\n\n    _backward: BackwardEnum = BackwardEnum._Split_Dlogits_N\n    _use_triton: bool = True\n\n\n_config = Config()\n\n\ndef set_backward_method(backward_method: BackwardEnum):\n    \"\"\"\n    Set the backward method.\n    \"\"\"\n    global _config\n    _config._backward = backward_method\n\n\n@triton.autotune(\n    configs=[triton.Config({\"BLOCK_SIZE_M\": 128, \"BLOCK_SIZE_N\": 256, \"BLOCK_SIZE_K\": 32}, num_stages=3, num_warps=8)],\n    key=[\"num_tokens\", \"hidden_size\", \"vocab_size\"],\n)\n@triton.jit\ndef efficient_entropy_kernel_general_mainloop(\n    rank,\n    hidden_ptr,\n    weight_ptr,\n    labels_ptr,\n    num_tokens,\n    hidden_size,\n    vocab_size,\n    vocab_per_split,\n    stride_hidden_m: tl.int64,\n    stride_hidden_k: tl.int64,\n    stride_weight_n: tl.int64,\n    stride_weight_k: tl.int64,\n    max_ptr,\n    stride_max_m: tl.int64,\n    stride_max_n: tl.int64,\n    accu_ptr,\n    stride_accu_m: tl.int64,\n    stride_accu_n: tl.int64,\n    entropy_b_ptr,\n    stride_entropy_b_m: tl.int64,\n    stride_entropy_b_n: tl.int64,\n    global_logprobs_ptr,\n    stride_global_logprobs: tl.int64,\n    global_logprobs_scalar_ptr,\n    rcp_temperature: tl.float32,\n    # Meta-parameters\n    BLOCK_SIZE_M: tl.constexpr,\n    BLOCK_SIZE_N: tl.constexpr,\n    BLOCK_SIZE_K: tl.constexpr,\n):\n    \"\"\"\n    forward mainloop\n    \"\"\"\n    pid = tl.program_id(axis=0)\n    num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split\n    num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M)\n    num_pid_n = tl.cdiv(vocab_per_split, BLOCK_SIZE_N)\n    pid_m = pid % num_pid_m\n    pid_n = pid // num_pid_m\n\n    if pid_m == 0 and pid_n == 0:\n        tl.store(global_logprobs_scalar_ptr, 0.0)\n\n    # create pointers for the first blocks of hidden\n    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n    hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k)\n\n    # load labels for this block\n    labels = tl.load(labels_ptr + offs_am, mask=offs_am < num_tokens)\n\n    # traverse over N dimension\n    # _max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)\n    _max = tl.full((BLOCK_SIZE_M,), -float(\"inf\"), dtype=tl.float32)\n    _accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)\n    _entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)\n    _logprobs = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)\n    for n in range(0, num_pid_n):\n        offs_bn = pid_n * vocab_per_split + n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n        # weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n)\n        weight_ptrs = weight_ptr + (offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k)\n\n        # iterate over K dimension\n        logits = 
tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n        for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)):\n            # load the next block of hidden and weight\n            _hidden = tl.load(\n                hidden_ptrs,\n                mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens),\n                other=0.0,\n            )\n            # _weight = tl.load(weight_ptrs,\n            #                   mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < (min(\n            #                       (pid_n + 1) * vocab_per_split, vocab_size))),\n            #                   other=0.0)\n\n            _weight = tl.load(\n                weight_ptrs,\n                mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K)\n                & (offs_bn[:, None] < (min((pid_n + 1) * vocab_per_split, vocab_size))),\n                other=0.0,\n            )\n\n            # GEMM\n            logits = tl.dot(_hidden, _weight.trans(), logits)\n\n            # advance the ptrs to the next K block\n            hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k\n            weight_ptrs += BLOCK_SIZE_K * stride_weight_k\n        # reset hidden_ptrs for next iteration\n        hidden_ptrs -= hidden_size * stride_hidden_k\n\n        # scale logits by temperature\n        logits *= rcp_temperature\n\n        # update global maximum\n        _max_old = _max\n        m_pid_n = tl.max(logits, axis=1)\n        _max = tl.maximum(_max_old, m_pid_n)\n\n        exp_logits = tl.exp(logits - _max[:, None])\n        coeff = tl.exp(_max_old - _max)\n        _accu = coeff * _accu + tl.sum(exp_logits, axis=1)\n\n        _entropy_b = _entropy_b * coeff + tl.sum(logits * exp_logits, axis=1)\n\n        label_mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None]\n        _logprobs += tl.sum(logits * label_mask, axis=1)\n\n    # store maximum\n    offs_max_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_max_n = pid_n\n    maximum_ptrs = max_ptr + offs_max_n * stride_max_n + offs_max_m * stride_max_m\n    tl.store(maximum_ptrs, _max, mask=(offs_max_m < num_tokens) & (offs_max_n < num_splits))\n\n    # store entropy\n    accu_ptrs = accu_ptr + offs_max_n * stride_accu_n + offs_max_m * stride_accu_m\n    tl.store(accu_ptrs, _accu, mask=(offs_max_m < num_tokens) & (offs_max_n[None] < num_splits))\n    entropy_b_ptrs = entropy_b_ptr + offs_max_n * stride_entropy_b_n + offs_max_m * stride_entropy_b_m\n    tl.store(entropy_b_ptrs, _entropy_b, mask=(offs_max_m < num_tokens) & (offs_max_n < num_splits))\n\n    # store logprobs\n    vocab_left_idx = pid_n * vocab_per_split + rank * vocab_size\n    vocab_right_idx = min((pid_n + 1) * vocab_per_split, vocab_size) + rank * vocab_size\n    mask = (labels >= vocab_left_idx) & (labels < vocab_right_idx)\n    mask &= offs_am < num_tokens\n    global_logprobs_ptrs = global_logprobs_ptr + offs_am * stride_global_logprobs\n    # tl.atomic_add(global_logprobs_ptrs, _logprobs, mask=mask)\n    tl.store(global_logprobs_ptrs, _logprobs, mask=mask)\n\n\n@triton.autotune(configs=[triton.Config({\"BLOCK_SIZE_M\": 16, \"BLOCK_SIZE_N\": 64})], key=[\"num_tokens\", \"num_splits\"])\n@triton.jit\ndef efficient_entropy_triton_kernel_epilogue(\n    max_ptr,\n    stride_max_m: tl.int64,\n    stride_max_n: tl.int64,\n    num_tokens,\n    num_splits,\n    global_max_ptr,\n    stride_global_max: tl.int64,\n    accu_ptr,\n    stride_accu_m: tl.int64,\n    stride_accu_n: tl.int64,\n    global_accu_ptr,\n   
 stride_global_accu: tl.int64,\n    entropy_b_ptr,\n    stride_entropy_b_m: tl.int64,\n    stride_entropy_b_n: tl.int64,\n    global_entropy_b_ptr,\n    stride_global_entropy_b: tl.int64,\n    global_entropy_ptr,\n    stride_global_entropy: tl.int64,\n    global_logprobs_ptr,\n    stride_global_logprobs: tl.int64,\n    global_logprobs_scalar_ptr,\n    reduction: int,\n    BLOCK_SIZE_M: tl.constexpr,\n    BLOCK_SIZE_N: tl.constexpr,\n):\n    \"\"\"\n    foward epilogue\n    \"\"\"\n    pid_m = tl.program_id(axis=0)\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    global_max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)\n    global_accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)\n    global_entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)\n    for pid_n in range(0, tl.cdiv(num_splits, BLOCK_SIZE_N)):\n        offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n        max_ptrs = max_ptr + offs_m[:, None] * stride_max_m + offs_n[None, :] * stride_max_n\n\n        _max = tl.load(max_ptrs, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0)\n\n        accu_ptrs = accu_ptr + offs_m[:, None] * stride_accu_m + offs_n[None, :] * stride_accu_n\n        _accu = tl.load(accu_ptrs, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0)\n\n        entropy_b_ptrs = entropy_b_ptr + offs_m[:, None] * stride_entropy_b_m + offs_n[None, :] * stride_entropy_b_n\n        _entropy_b = tl.load(\n            entropy_b_ptrs, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0\n        )\n\n        # local reduction\n        _max_old = global_max\n        _local_max = tl.max(_max, axis=1)\n        global_max = tl.maximum(global_max, _local_max)\n\n        _scale = tl.exp(_max - global_max[:, None])\n        _coeff = tl.exp(_max_old - global_max)\n        global_accu = _coeff * global_accu + tl.sum(_scale * _accu, axis=1)\n        global_entropy_b = _coeff * global_entropy_b + tl.sum(_scale * _entropy_b, axis=1)\n\n    # store\n    maximum_ptrs = global_max_ptr + offs_m * stride_global_max\n    tl.store(maximum_ptrs, global_max, mask=offs_m < num_tokens)\n\n    # store entropy_b\n    global_entropy_b = tl.fdiv(global_entropy_b, global_accu)  # entropy_b\n    tl.store(global_entropy_b_ptr + offs_m * stride_global_entropy_b, global_entropy_b, mask=offs_m < num_tokens)\n\n    # store entropy\n    global_accu_ptrs = global_accu_ptr + offs_m * stride_global_accu\n    tl.store(global_accu_ptrs, global_accu, mask=offs_m < num_tokens)\n    global_entropy = tl.log(global_accu) + global_max - global_entropy_b  # entropy_a\n    global_entropy_ptrs = global_entropy_ptr + offs_m * stride_global_entropy\n    tl.store(global_entropy_ptrs, global_entropy, mask=offs_m < num_tokens)\n    # update logprobs\n    global_logprobs_ptrs = global_logprobs_ptr + offs_m * stride_global_logprobs\n    global_logprobs = tl.load(global_logprobs_ptrs, mask=offs_m < num_tokens)\n    global_logprobs = global_max + tl.log(global_accu) - global_logprobs\n\n    global_logprobs = -1 * global_logprobs\n    if reduction == 0:\n        tl.store(global_logprobs_ptrs, global_logprobs, mask=offs_m < num_tokens)\n    elif reduction == 1:\n        global_logprobs_scalar = tl.sum(global_logprobs, axis=0)\n        tl.atomic_add(global_logprobs_scalar_ptr, global_logprobs_scalar)\n    elif reduction == 2:\n        global_logprobs_scalar = tl.sum(global_logprobs, axis=0) / num_tokens.to(tl.float32)\n        
tl.atomic_add(global_logprobs_scalar_ptr, global_logprobs_scalar)\n\n\n@triton.autotune(configs=[triton.Config({\"BLOCK_SIZE_M\": 16, \"BLOCK_SIZE_N\": 64})], key=[\"num_tokens\", \"num_splits\"])\n@triton.jit\ndef efficient_entropy_triton_kernel_epilogue_tp(\n    num_tokens,\n    num_splits,\n    reduced_max_ptr,\n    stride_reduced_max_m: tl.int64,\n    stride_reduced_max_n: tl.int64,\n    original_max_ptr,\n    stride_original_max_m: tl.int64,\n    stride_original_max_n: tl.int64,\n    accu_ptr,\n    stride_accu_m: tl.int64,\n    stride_accu_n: tl.int64,\n    entropy_b_ptr,\n    stride_entropy_b_m: tl.int64,\n    stride_entropy_b_n: tl.int64,\n    global_max_ptr,\n    stride_global_max: tl.int64,\n    global_accu_ptr,\n    stride_global_accu: tl.int64,\n    global_entropy_b_ptr,\n    stride_global_entropy_b: tl.int64,\n    BLOCK_SIZE_M: tl.constexpr,\n    BLOCK_SIZE_N: tl.constexpr,\n):\n    pid_m = tl.program_id(axis=0)\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n\n    global_max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)\n    global_accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)\n    global_entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)\n    for pid_n in range(0, tl.cdiv(num_splits, BLOCK_SIZE_N)):\n        offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n\n        _reduced_max = tl.load(\n            reduced_max_ptr + offs_m[:, None] * stride_reduced_max_m + offs_n[None, :] * stride_reduced_max_n,\n            mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits),\n            other=0.0,\n        )\n        _original_max = tl.load(\n            original_max_ptr + offs_m[:, None] * stride_original_max_m + offs_n[None, :] * stride_original_max_n,\n            mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits),\n            other=0.0,\n        )\n        _accu = tl.load(\n            accu_ptr + offs_m[:, None] * stride_accu_m + offs_n[None, :] * stride_accu_n,\n            mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits),\n            other=0.0,\n        )\n\n        # local reduce-max\n        _max_old = global_max\n        _local_max = tl.max(_reduced_max, axis=1)\n        global_max = tl.maximum(global_max, _local_max)\n\n        # update accumulate\n        _coeff = tl.exp(_max_old - global_max)\n        _scale = tl.exp(_original_max - global_max[:, None])\n        global_accu = _coeff * global_accu + tl.sum(_scale * _accu, axis=1)\n\n        # update entropy_b\n        _entropy_b = tl.load(\n            entropy_b_ptr + offs_m[:, None] * stride_entropy_b_m + offs_n[None, :] * stride_entropy_b_n,\n            mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits),\n            other=0.0,\n        )\n        global_entropy_b = _coeff * global_entropy_b + tl.sum(_scale * _entropy_b, axis=1)\n\n    # store\n    tl.store(global_max_ptr + offs_m * stride_global_max, global_max, mask=offs_m < num_tokens)\n    tl.store(global_accu_ptr + offs_m * stride_global_accu, global_accu, mask=offs_m < num_tokens)\n    tl.store(global_entropy_b_ptr + offs_m * stride_global_entropy_b, global_entropy_b, mask=offs_m < num_tokens)\n\n\n@triton.autotune(configs=[triton.Config({\"BLOCK_SIZE_M\": 16})], key=[\"num_tokens\"])\n@triton.jit\ndef efficient_entropy_triton_epilogue_tp_update(\n    num_tokens,\n    logprobs_ptr,\n    stride_logprobs: tl.int64,\n    maximum_ptr,\n    stride_maximum: tl.int64,\n    accumulate_ptr,\n    stride_accumulate: tl.int64,\n    entropy_b_ptr,\n    
stride_entropy_b: tl.int64,\n    entropy_ptr,\n    stride_entropy: tl.int64,\n    logprobs_scalar_ptr,\n    reduction: int,\n    BLOCK_SIZE_M: tl.constexpr,\n):\n    pid_m = tl.program_id(axis=0)\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n\n    maximum = tl.load(maximum_ptr + offs_m * stride_maximum, mask=offs_m < num_tokens)\n    accumulate = tl.load(accumulate_ptr + offs_m * stride_accumulate, mask=offs_m < num_tokens)\n\n    entropy_b = tl.load(entropy_b_ptr + offs_m * stride_entropy_b, mask=offs_m < num_tokens)\n    entropy_b = tl.fdiv(entropy_b, accumulate)\n    tl.store(entropy_b_ptr + offs_m * stride_entropy_b, entropy_b, mask=offs_m < num_tokens)\n\n    entropy = tl.log(accumulate) + maximum - entropy_b\n    tl.store(entropy_ptr + offs_m * stride_entropy, entropy, mask=offs_m < num_tokens)\n\n    logprobs = tl.load(logprobs_ptr + offs_m * stride_logprobs, mask=offs_m < num_tokens)\n    logprobs = maximum + tl.log(accumulate) - logprobs\n\n    logprobs = -1 * logprobs\n    if reduction == 0:\n        tl.store(logprobs_ptr + offs_m * stride_logprobs, logprobs, mask=offs_m < num_tokens)\n    elif reduction == 1:\n        logprobs_scalar = tl.sum(logprobs, axis=0)\n        tl.atomic_add(logprobs_scalar_ptr, logprobs_scalar)\n    elif reduction == 2:\n        logprobs_scalar = tl.sum(logprobs, axis=0) / num_tokens.to(tl.float32)\n        tl.atomic_add(logprobs_scalar_ptr, logprobs_scalar)\n\n\n_dedicated_stream, _dedicated_events = None, None\n\n\ndef efficient_entropy_forward(\n    hidden: torch.Tensor,\n    weight: torch.Tensor,\n    labels: torch.Tensor,\n    reduction: typing.Optional[int] = 2,\n    temperature: typing.Optional[float] = 1.0,\n    dist_process_group: typing.Optional[dist.ProcessGroup] = None,\n) -> list[torch.Tensor]:\n    \"\"\"\n    forward host function\n    \"\"\"\n    assert hidden.is_cuda and weight.is_cuda and labels.is_cuda\n    assert weight.device == hidden.device and labels.device == hidden.device\n    assert hidden.dim() == 2 and weight.dim() == 2 and labels.dim() == 1\n    assert hidden.is_contiguous() and weight.is_contiguous() and labels.is_contiguous()\n\n    assert hidden.shape[0] == labels.shape[0] and hidden.shape[1] == weight.shape[1]\n\n    _rank = 0 if dist_process_group is None else dist.get_rank(dist_process_group)\n    _world_size = 1 if dist_process_group is None else dist.get_world_size(dist_process_group)\n\n    if dist_process_group is not None and not hasattr(efficient_entropy_forward, \"_initialized\"):\n        global _dedicated_stream, _dedicated_events\n        _dedicated_stream = get_torch_device().Stream(hidden.device)\n        _dedicated_events = [get_torch_device().Event() for _ in range(2)]\n        efficient_entropy_forward._initialized = True\n\n    num_tokens, hidden_size = hidden.shape\n    num_tokens = labels.shape[0]\n    vocab_size, hidden_size = weight.shape\n    assert hidden_size % 128 == 0\n\n    REDUCTION = get_entropy_reduction_enum(reduction)\n\n    if REDUCTION == EntropyReductionEnum._None:\n        if dist_process_group is None:\n            logprobs = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32)\n        else:\n            logprobs = torch.zeros((num_tokens,), device=hidden.device, dtype=torch.float32)\n    elif REDUCTION in (EntropyReductionEnum._Sum, EntropyReductionEnum._Mean):\n        logprobs = torch.empty((), device=hidden.device, dtype=torch.float32)\n    else:\n        raise ValueError(f\"Invalid reduction: {reduction}\")\n\n    entropy = 
torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32)\n    assert logprobs.is_contiguous() and entropy.is_contiguous()\n\n    maximum = torch.empty_like(entropy)\n    accumulate_and_entropy_b = torch.empty((num_tokens * 2,), device=hidden.device, dtype=torch.float32)\n    accumulate_and_entropy_b_view = accumulate_and_entropy_b.view(2, num_tokens)\n    accumulate = accumulate_and_entropy_b_view[0, :]\n    entropy_b = accumulate_and_entropy_b_view[1, :]\n    assert maximum.is_contiguous() and accumulate.is_contiguous() and entropy_b.is_contiguous()\n\n    vocab_per_split = 1024\n    assert vocab_per_split % 128 == 0\n    num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split\n\n    _max = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32)\n    _accu = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32)\n    _entropy_b = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32)\n\n    if REDUCTION == EntropyReductionEnum._None:\n        _logprobs = logprobs\n    else:\n        _logprobs = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32)\n\n    assert _accu.is_contiguous() and _entropy_b.is_contiguous() and _max.is_contiguous()\n    assert _accu.is_cuda and _entropy_b.is_cuda and _max.is_cuda\n\n    if _config._use_triton:\n        # 1D kernel launch, then split the tile\n        def mainloop_grid(meta):\n            return (triton.cdiv(num_tokens, meta[\"BLOCK_SIZE_M\"]) * num_splits,)\n\n        efficient_entropy_kernel_general_mainloop[mainloop_grid](\n            _rank,\n            hidden,\n            weight,\n            labels,\n            num_tokens,\n            hidden_size,\n            vocab_size,\n            vocab_per_split,\n            hidden.stride(0),\n            hidden.stride(1),\n            weight.stride(0),\n            weight.stride(1),\n            _max,\n            _max.stride(0),\n            _max.stride(1),\n            _accu,\n            _accu.stride(0),\n            _accu.stride(1),\n            _entropy_b,\n            _entropy_b.stride(0),\n            _entropy_b.stride(1),\n            _logprobs,\n            _logprobs.stride(0),\n            logprobs,\n            1.0 / temperature,\n        )\n    else:\n        raise AssertionError(\"Triton is required for efficient entropy kernel\")\n\n    # reduction on maximum and maximum_indices\n    def epilogue_grid(meta):\n        return (triton.cdiv(num_tokens, meta[\"BLOCK_SIZE_M\"]),)\n\n    if dist_process_group is None:\n        efficient_entropy_triton_kernel_epilogue[epilogue_grid](\n            _max,\n            _max.stride(0),\n            _max.stride(1),\n            num_tokens,\n            num_splits,\n            maximum,\n            maximum.stride(0),\n            _accu,\n            _accu.stride(0),\n            _accu.stride(1),\n            accumulate,\n            accumulate.stride(0),\n            _entropy_b,\n            _entropy_b.stride(0),\n            _entropy_b.stride(1),\n            entropy_b,\n            entropy_b.stride(0),\n            entropy,\n            entropy.stride(0),\n            _logprobs,\n            _logprobs.stride(0),\n            logprobs,\n            REDUCTION,\n        )\n    else:\n        # tensor-parallel\n        _max_backup = _max.clone()\n        dist.all_reduce(_max, op=dist.ReduceOp.MAX, group=dist_process_group)\n\n        get_torch_device().current_stream().record_event(_dedicated_events[0])\n        with 
get_torch_device().stream(_dedicated_stream):\n            _dedicated_stream.wait_event(_dedicated_events[0])\n            dist.all_reduce(_logprobs, op=dist.ReduceOp.SUM, group=dist_process_group)\n            _dedicated_stream.record_event(_dedicated_events[1])\n\n        efficient_entropy_triton_kernel_epilogue_tp[epilogue_grid](\n            num_tokens,\n            num_splits,\n            _max,\n            _max.stride(0),\n            _max.stride(1),\n            _max_backup,\n            _max_backup.stride(0),\n            _max_backup.stride(1),\n            _accu,\n            _accu.stride(0),\n            _accu.stride(1),\n            _entropy_b,\n            _entropy_b.stride(0),\n            _entropy_b.stride(1),\n            maximum,\n            maximum.stride(0),\n            accumulate,\n            accumulate.stride(0),\n            entropy_b,\n            entropy_b.stride(0),\n        )\n        get_torch_device().current_stream().wait_event(_dedicated_events[1])\n\n        dist.all_reduce(accumulate_and_entropy_b, op=dist.ReduceOp.SUM, group=dist_process_group)\n\n        # update logprobs & entropy\n        efficient_entropy_triton_epilogue_tp_update[epilogue_grid](\n            num_tokens,\n            _logprobs,\n            _logprobs.stride(0),\n            maximum,\n            maximum.stride(0),\n            accumulate,\n            accumulate.stride(0),\n            entropy_b,\n            entropy_b.stride(0),\n            entropy,\n            entropy.stride(0),\n            logprobs,\n            REDUCTION,\n        )\n\n    return (logprobs, entropy, maximum, accumulate, entropy_b)\n\n\n# NOTE: merge d_weight & d_hidden here, split along M & N\n@triton.autotune(\n    configs=[\n        triton.Config(\n            {\"BLOCK_SIZE_M\": 128, \"BLOCK_SIZE_N\": 128, \"BLOCK_SIZE_K\": 32, \"GROUP_SIZE_M\": 16},\n            num_stages=3,\n            num_warps=8,\n        )\n    ],\n    key=[\"num_tokens\", \"hidden_size\", \"vocab_size\"],\n)\n@triton.jit\ndef efficient_entropy_backward_kernel_general_mainloop_MN(\n    num_tokens: int,\n    hidden_size: int,\n    vocab_size: int,\n    rank: int,\n    hidden_ptr,\n    stride_hidden_m: tl.int64,\n    stride_hidden_k: tl.int64,\n    weight_ptr,\n    stride_weight_n: tl.int64,\n    stride_weight_k: tl.int64,\n    labels_ptr,\n    stride_labels: tl.int64,\n    maximum_ptr,\n    stride_maximum: tl.int64,\n    accu_ptr,\n    stride_accu: tl.int64,\n    d_entropy_ptr,\n    stride_d_entropy: tl.int64,\n    d_logprobs_ptr,\n    stride_d_logprobs: tl.int64,\n    reduction: int,\n    entropy_b_ptr,\n    stride_entropy_b: tl.int64,\n    d_hidden_ptr,\n    stride_d_hidden_m: tl.int64,\n    stride_d_hidden_k: tl.int64,\n    d_weight_ptr,\n    stride_d_weight_n: tl.int64,\n    stride_d_weight_k: tl.int64,\n    rcp_temperature: tl.float32,\n    BLOCK_SIZE_M: tl.constexpr,\n    BLOCK_SIZE_N: tl.constexpr,\n    BLOCK_SIZE_K: tl.constexpr,\n    GROUP_SIZE_M: tl.constexpr,\n):\n    \"\"\"\n    backward mainloop, where d_logits & d_hidden & d_weight are fused\n    \"\"\"\n    # block swizzling\n    # pid = tl.program_id(axis=0)\n    # num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M)\n    # pid_m = pid % num_pid_m\n    # pid_n = pid // num_pid_m\n\n    pid = tl.program_id(axis=0)\n    num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M)\n    num_pid_n = tl.cdiv(vocab_size, BLOCK_SIZE_N)\n    num_pid_in_group = GROUP_SIZE_M * num_pid_n\n    group_id = pid // num_pid_in_group\n    first_pid_m = group_id * GROUP_SIZE_M\n    group_size_m = min(num_pid_m 
- first_pid_m, GROUP_SIZE_M)\n    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)\n    pid_n = (pid % num_pid_in_group) // group_size_m\n\n    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n\n    maximum_ptrs = maximum_ptr + offs_am * stride_maximum\n    maximum = tl.load(maximum_ptrs, mask=offs_am < num_tokens, other=0.0)\n    accu_ptrs = accu_ptr + offs_am * stride_accu\n    accu = tl.load(accu_ptrs, mask=offs_am < num_tokens, other=1e-6)  # epsilon to avoid division by zero\n    accu_rcp = tl.fdiv(1.0, accu)\n\n    d_entropy_ptrs = d_entropy_ptr + offs_am * stride_d_entropy\n    d_entropy = tl.load(d_entropy_ptrs, mask=offs_am < num_tokens, other=0.0)\n    if reduction == 0:  # none\n        d_logprobs_ptrs = d_logprobs_ptr + offs_am * stride_d_logprobs\n        d_logprobs = tl.load(d_logprobs_ptrs, mask=offs_am < num_tokens, other=0.0)\n    elif reduction == 1:  # sum\n        d_logprobs = tl.load(d_logprobs_ptr)\n        d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,))\n    else:  # mean\n        d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32))\n        d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,))\n    d_logprobs = -1 * d_logprobs\n\n    entropy_b_ptrs = entropy_b_ptr + offs_am * stride_entropy_b\n    entropy_b = tl.load(entropy_b_ptrs, mask=offs_am < num_tokens, other=0.0)\n\n    hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k)\n    # weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n)\n    weight_ptrs = weight_ptr + (offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k)\n    labels_ptrs = labels_ptr + offs_am * stride_labels\n    labels = tl.load(labels_ptrs, mask=offs_am < num_tokens, other=0)\n\n    d_hidden_ptrs = d_hidden_ptr + offs_am[:, None] * stride_d_hidden_m + offs_k[None, :] * stride_d_hidden_k\n    # d_weight_ptrs = d_weight_ptr + offs_k[:, None] * stride_d_weight_k + offs_bn[None, :] * stride_d_weight_n\n    d_weight_ptrs = d_weight_ptr + offs_bn[:, None] * stride_d_weight_n + offs_k[None, :] * stride_d_weight_k\n\n    logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n    for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)):\n        _hidden = tl.load(\n            hidden_ptrs,\n            mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens),\n            other=0.0,\n        )\n        # _weight = tl.load(weight_ptrs,\n        #                   mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < vocab_size),\n        #                   other=0.0)\n        _weight = tl.load(\n            weight_ptrs,\n            mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_size),\n            other=0.0,\n        )\n\n        logits = tl.dot(_hidden, _weight.trans(), logits)\n\n        hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k\n        weight_ptrs += BLOCK_SIZE_K * stride_weight_k\n    hidden_ptrs -= hidden_size * stride_hidden_k\n    weight_ptrs -= hidden_size * stride_weight_k\n\n    # scale logits by temperature\n    logits *= rcp_temperature\n\n    exp_logits = tl.exp(logits - maximum[:, None])\n\n    mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None]\n    d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask)\n    d_logits += 
d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None])\n\n    # scale d_logits by temperature\n    d_logits *= rcp_temperature\n\n    # loop for d_weight & d_hidden\n    for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)):\n        _hidden = tl.load(\n            hidden_ptrs,\n            mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens),\n            other=0.0,\n        )\n        # _d_weight = tl.dot(tl.trans(_hidden).to(tl.float32), d_logits)\n        # tl.atomic_add(d_weight_ptrs,\n        #               _d_weight,\n        #               mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < vocab_size))\n        _d_weight = tl.dot(d_logits.trans(), _hidden.to(tl.float32))\n        tl.atomic_add(\n            d_weight_ptrs,\n            _d_weight,\n            mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_size),\n        )\n\n        # _weight = tl.load(weight_ptrs,\n        #                   mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < vocab_size),\n        #                   other=0.0)\n        # _d_hidden = tl.dot(d_logits, tl.trans(_weight).to(tl.float32))\n        _weight = tl.load(\n            weight_ptrs,\n            mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_size),\n            other=0.0,\n        )\n        _d_hidden = tl.dot(d_logits, _weight.to(tl.float32))\n        tl.atomic_add(\n            d_hidden_ptrs,\n            _d_hidden,\n            mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens),\n        )\n\n        hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k\n        weight_ptrs += BLOCK_SIZE_K * stride_weight_k\n        d_hidden_ptrs += BLOCK_SIZE_K * stride_d_hidden_k\n        d_weight_ptrs += BLOCK_SIZE_K * stride_d_weight_k\n\n\n@triton.autotune(\n    configs=[\n        triton.Config(\n            {\"BLOCK_SIZE_M\": 128, \"BLOCK_SIZE_N\": 128, \"BLOCK_SIZE_K\": 32, \"GROUP_SIZE_M\": 16},\n            num_stages=3,\n            num_warps=8,\n        ),\n    ],\n    key=[\"num_tokens\", \"hidden_size\", \"vocab_size\"],\n)\n@triton.jit\ndef efficient_entropy_backward_kernel_d_hidden(\n    num_tokens: int,\n    hidden_size: int,\n    vocab_size: int,\n    rank: int,\n    hidden_ptr,\n    stride_hidden_m: tl.int64,\n    stride_hidden_k: tl.int64,\n    weight_ptr,\n    stride_weight_n: tl.int64,\n    stride_weight_k: tl.int64,\n    labels_ptr,\n    stride_labels: tl.int64,\n    maximum_ptr,\n    stride_maximum: tl.int64,\n    accu_ptr,\n    stride_accu: tl.int64,\n    d_entropy_ptr,\n    stride_d_entropy: tl.int64,\n    d_logprobs_ptr,\n    stride_d_logprobs: tl.int64,\n    reduction: int,\n    entropy_b_ptr,\n    stride_entropy_b: tl.int64,\n    d_hidden_ptr,\n    stride_d_hidden_m: tl.int64,\n    stride_d_hidden_k: tl.int64,\n    rcp_temperature: tl.float32,\n    BLOCK_SIZE_M: tl.constexpr,\n    BLOCK_SIZE_N: tl.constexpr,\n    BLOCK_SIZE_K: tl.constexpr,\n    GROUP_SIZE_M: tl.constexpr,\n):\n    \"\"\"\n    backward d_hidden\n    \"\"\"\n    pid = tl.program_id(axis=0)\n    num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M)\n    pid_m = pid % num_pid_m\n    pid_k = pid // num_pid_m\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n    result_offs_k = pid_k * BLOCK_SIZE_K + offs_k\n\n    maximum = tl.load(maximum_ptr + offs_m * stride_maximum, mask=offs_m < 
num_tokens, other=0.0)\n    accu = tl.load(accu_ptr + offs_m * stride_accu, mask=offs_m < num_tokens, other=1e-6)\n    accu_rcp = tl.fdiv(1.0, accu)\n    d_entropy = tl.load(d_entropy_ptr + offs_m * stride_d_entropy, mask=offs_m < num_tokens, other=0.0)\n    if reduction == 0:\n        d_logprobs = tl.load(d_logprobs_ptr + offs_m * stride_d_logprobs, mask=offs_m < num_tokens, other=0.0)\n    elif reduction == 1:\n        d_logprobs = tl.load(d_logprobs_ptr)\n        d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,))\n    else:\n        d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32))\n        d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,))\n    d_logprobs = -1 * d_logprobs\n\n    entropy_b = tl.load(entropy_b_ptr + offs_m * stride_entropy_b, mask=offs_m < num_tokens, other=0.0)\n    labels = tl.load(labels_ptr + offs_m * stride_labels, mask=offs_m < num_tokens, other=0)\n\n    # iterate over vocab_size\n    d_hidden = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)\n    for n in range(0, tl.cdiv(vocab_size, BLOCK_SIZE_N)):\n        offs_n = n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n\n        hidden_ptrs = hidden_ptr + (offs_m[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k)\n        weight_ptrs = weight_ptr + (offs_n[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k)\n\n        # iterate over hidden_size to get logits\n        logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n        for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)):\n            _hidden = tl.load(\n                hidden_ptrs,\n                mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_m[:, None] < num_tokens),\n                other=0.0,\n            )\n            _weight = tl.load(\n                weight_ptrs,\n                mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_n[:, None] < vocab_size),\n                other=0.0,\n            )\n\n            logits = tl.dot(_hidden, _weight.trans(), logits)\n\n            hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k\n            weight_ptrs += BLOCK_SIZE_K * stride_weight_k\n\n        # scale logits by temperature\n        logits *= rcp_temperature\n\n        exp_logits = tl.exp(logits - maximum[:, None])\n\n        mask = (offs_n + rank * vocab_size)[None, :] == labels[:, None]\n        d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask)\n        d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None])\n\n        # scale d_logits\n        d_logits *= rcp_temperature\n\n        # calculate d_hidden\n        weight_ptrs = weight_ptr + (offs_n[:, None] * stride_weight_n + result_offs_k[None, :] * stride_weight_k)\n        _weight = tl.load(\n            weight_ptrs, mask=(result_offs_k[None, :] < hidden_size) & (offs_n[:, None] < vocab_size), other=0.0\n        )\n        d_hidden = tl.dot(d_logits.to(weight_ptr.dtype.element_ty), _weight, d_hidden)\n\n    # write back\n    tl.store(\n        d_hidden_ptr + offs_m[:, None] * stride_d_hidden_m + result_offs_k[None, :] * stride_d_hidden_k,\n        d_hidden,\n        mask=(offs_m[:, None] < num_tokens) & (result_offs_k[None, :] < hidden_size),\n    )\n\n\n@triton.autotune(\n    configs=[\n        triton.Config(\n            {\"BLOCK_SIZE_M\": 128, \"BLOCK_SIZE_N\": 128, \"BLOCK_SIZE_K\": 32, \"GROUP_SIZE_M\": 16},\n            num_stages=3,\n            num_warps=8,\n        ),\n    ],\n    key=[\"num_tokens\", 
\"hidden_size\", \"vocab_size\"],\n)\n@triton.jit\ndef efficient_entropy_backward_kernel_d_weight(\n    num_tokens: int,\n    hidden_size: int,\n    vocab_size: int,\n    rank: int,\n    hidden_ptr,\n    stride_hidden_m: tl.int64,\n    stride_hidden_k: tl.int64,\n    weight_ptr,\n    stride_weight_n: tl.int64,\n    stride_weight_k: tl.int64,\n    labels_ptr,\n    stride_labels: tl.int64,\n    maximum_ptr,\n    stride_maximum: tl.int64,\n    accu_ptr,\n    stride_accu: tl.int64,\n    d_entropy_ptr,\n    stride_d_entropy: tl.int64,\n    d_logprobs_ptr,\n    stride_d_logprobs: tl.int64,\n    reduction: int,\n    entropy_b_ptr,\n    stride_entropy_b: tl.int64,\n    d_weight_ptr,\n    stride_d_weight_n: tl.int64,\n    stride_d_weight_k: tl.int64,\n    rcp_temperature: tl.float32,\n    BLOCK_SIZE_M: tl.constexpr,\n    BLOCK_SIZE_N: tl.constexpr,\n    BLOCK_SIZE_K: tl.constexpr,\n    GROUP_SIZE_M: tl.constexpr,\n):\n    pid = tl.program_id(axis=0)\n    num_pid_n = tl.cdiv(vocab_size, BLOCK_SIZE_N)\n    pid_n = pid % num_pid_n\n    pid_k = pid // num_pid_n\n\n    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n    result_offs_k = pid_k * BLOCK_SIZE_K + offs_k\n\n    d_weight = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32)\n    for m in range(0, tl.cdiv(num_tokens, BLOCK_SIZE_M)):\n        offs_m = m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n\n        maximum = tl.load(maximum_ptr + offs_m * stride_maximum, mask=offs_m < num_tokens, other=0.0)\n        accu = tl.load(accu_ptr + offs_m * stride_accu, mask=offs_m < num_tokens, other=1e-6)\n        accu_rcp = tl.fdiv(1.0, accu)\n        d_entropy = tl.load(d_entropy_ptr + offs_m * stride_d_entropy, mask=offs_m < num_tokens, other=0.0)\n        if reduction == 0:\n            d_logprobs = tl.load(d_logprobs_ptr + offs_m * stride_d_logprobs, mask=offs_m < num_tokens, other=0.0)\n        elif reduction == 1:\n            d_logprobs = tl.load(d_logprobs_ptr)\n            d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,))\n        else:\n            d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32))\n            d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,))\n        d_logprobs = -1 * d_logprobs\n\n        entropy_b = tl.load(entropy_b_ptr + offs_m * stride_entropy_b, mask=offs_m < num_tokens, other=0.0)\n        labels = tl.load(labels_ptr + offs_m * stride_labels, mask=offs_m < num_tokens, other=0)\n\n        hidden_ptrs = hidden_ptr + (offs_m[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k)\n        weight_ptrs = weight_ptr + (offs_n[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k)\n\n        logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n        for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)):\n            _hidden = tl.load(\n                hidden_ptrs,\n                mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_m[:, None] < num_tokens),\n                other=0.0,\n            )\n            _weight = tl.load(\n                weight_ptrs,\n                mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_n[:, None] < vocab_size),\n                other=0.0,\n            )\n\n            logits = tl.dot(_hidden, _weight.trans(), logits)\n\n            hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k\n            weight_ptrs += BLOCK_SIZE_K * stride_weight_k\n\n        logits *= rcp_temperature\n\n        exp_logits = tl.exp(logits - maximum[:, None])\n\n      
  mask = (offs_n + rank * vocab_size)[None, :] == labels[:, None]\n        d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask)\n        d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None])\n\n        d_logits *= rcp_temperature\n\n        hidden_ptrs = hidden_ptr + (offs_m[:, None] * stride_hidden_m + result_offs_k[None, :] * stride_hidden_k)\n        _hidden = tl.load(\n            hidden_ptrs, mask=(result_offs_k[None, :] < hidden_size) & (offs_m[:, None] < num_tokens), other=0.0\n        )\n        d_weight = tl.dot(d_logits.to(d_weight_ptr.dtype.element_ty).trans(), _hidden, d_weight)\n\n    # write back\n    tl.store(\n        d_weight_ptr + offs_n[:, None] * stride_d_weight_n + result_offs_k[None, :] * stride_d_weight_k,\n        d_weight,\n        mask=(offs_n[:, None] < vocab_size) & (result_offs_k[None, :] < hidden_size),\n    )\n\n\n# NOTE: split tile from d_logits' perspective\n@triton.autotune(\n    configs=[\n        triton.Config(\n            {\"BLOCK_SIZE_M\": 128, \"BLOCK_SIZE_N\": 256, \"BLOCK_SIZE_K\": 32, \"GROUP_SIZE_M\": 16},\n            num_stages=3,\n            num_warps=8,\n        ),\n    ],\n    key=[\"num_tokens\", \"hidden_size\", \"vocab_size\"],\n)\n@triton.jit\ndef efficient_entropy_backward_kernel_general_d_logits(\n    num_tokens: int,\n    hidden_size: int,\n    vocab_size: int,\n    rank: int,\n    hidden_ptr,\n    stride_hidden_m: tl.int64,\n    stride_hidden_k: tl.int64,\n    weight_ptr,\n    stride_weight_n: tl.int64,\n    stride_weight_k: tl.int64,\n    labels_ptr,\n    stride_labels: tl.int64,\n    maximum_ptr,\n    stride_maximum: tl.int64,\n    accu_ptr,\n    stride_accu: tl.int64,\n    d_entropy_ptr,\n    stride_d_entropy: tl.int64,\n    d_logprobs_ptr,\n    stride_d_logprobs: tl.int64,\n    reduction: int,\n    entropy_b_ptr,\n    stride_entropy_b,\n    d_logits_ptr,\n    stride_d_logits_m: tl.int64,\n    stride_d_logits_n: tl.int64,\n    rcp_temperature: tl.float32,\n    BLOCK_SIZE_M: tl.constexpr,\n    BLOCK_SIZE_N: tl.constexpr,\n    BLOCK_SIZE_K: tl.constexpr,\n    GROUP_SIZE_M: tl.constexpr,\n):\n    \"\"\"\n    backward d_logits\n    \"\"\"\n    # block swizzling\n    # pid = tl.program_id(axis=0)\n    # num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M)\n    # pid_m = pid % num_pid_m\n    # pid_n = pid // num_pid_m\n\n    pid = tl.program_id(axis=0)\n    num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M)\n    num_pid_n = tl.cdiv(vocab_size, BLOCK_SIZE_N)\n    num_pid_in_group = GROUP_SIZE_M * num_pid_n\n    group_id = pid // num_pid_in_group\n    first_pid_m = group_id * GROUP_SIZE_M\n    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)\n    pid_n = (pid % num_pid_in_group) // group_size_m\n\n    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n\n    maximum_ptrs = maximum_ptr + offs_am * stride_maximum\n    maximum = tl.load(maximum_ptrs, mask=offs_am < num_tokens, other=0.0)\n    accu_ptrs = accu_ptr + offs_am * stride_accu\n    accu = tl.load(accu_ptrs, mask=offs_am < num_tokens, other=1e-6)  # epsilon to avoid division by zero\n    accu_rcp = tl.fdiv(1.0, accu)\n\n    d_entropy_ptrs = d_entropy_ptr + offs_am * stride_d_entropy\n    d_entropy = tl.load(d_entropy_ptrs, mask=offs_am < num_tokens, other=0.0)\n    if reduction == 0:  # none\n        d_logprobs_ptrs = d_logprobs_ptr + 
offs_am * stride_d_logprobs\n        d_logprobs = tl.load(d_logprobs_ptrs, mask=offs_am < num_tokens, other=0.0)\n    elif reduction == 1:  # sum\n        d_logprobs = tl.load(d_logprobs_ptr)\n        d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,))\n    else:  # mean\n        d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32))\n        d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,))\n    d_logprobs = -1 * d_logprobs\n\n    entropy_b_ptrs = entropy_b_ptr + offs_am * stride_entropy_b\n    entropy_b = tl.load(entropy_b_ptrs, mask=offs_am < num_tokens, other=0.0)\n\n    hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k)\n    # weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n)\n    weight_ptrs = weight_ptr + (offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k)\n    labels_ptrs = labels_ptr + offs_am * stride_labels\n    labels = tl.load(labels_ptrs, mask=offs_am < num_tokens, other=0)\n\n    logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n    for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)):\n        _hidden = tl.load(\n            hidden_ptrs,\n            mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens),\n            other=0.0,\n        )\n        # _weight = tl.load(weight_ptrs,\n        #                   mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < vocab_size),\n        #                   other=0.0)\n        _weight = tl.load(\n            weight_ptrs,\n            mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_size),\n            other=0.0,\n        )\n\n        logits = tl.dot(_hidden, _weight.trans(), logits)\n\n        hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k\n        weight_ptrs += BLOCK_SIZE_K * stride_weight_k\n    hidden_ptrs -= hidden_size * stride_hidden_k\n    weight_ptrs -= hidden_size * stride_weight_k\n\n    # scale logits by temperature\n    logits *= rcp_temperature\n\n    exp_logits = tl.exp(logits - maximum[:, None])\n\n    mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None]\n    d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask)\n    d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None])\n\n    # scale d_logits by temperature\n    d_logits *= rcp_temperature\n\n    # store d_logits\n    d_logits_ptrs = d_logits_ptr + offs_am[:, None] * stride_d_logits_m + offs_bn[None, :] * stride_d_logits_n\n    tl.store(\n        d_logits_ptrs,\n        d_logits,  # will be implicitly converted to d_logits_ptrs.dtype.element_ty\n        mask=(offs_am[:, None] < num_tokens) & (offs_bn[None, :] < vocab_size),\n    )\n\n\n@triton.autotune(\n    configs=[\n        triton.Config(\n            {\"BLOCK_SIZE_M\": 128, \"BLOCK_SIZE_N\": 256, \"BLOCK_SIZE_K\": 32, \"GROUP_SIZE_M\": 16},\n            num_stages=3,\n            num_warps=8,\n        ),\n    ],\n    key=[\"num_tokens\", \"hidden_size\", \"vocab_size\"],\n)\n@triton.jit\ndef efficient_entropy_backward_kernel_general_d_logits_split_N(\n    split_idx: int,\n    num_tokens: int,\n    hidden_size: int,\n    vocab_size: int,\n    vocab_per_split: int,\n    rank: int,\n    hidden_ptr,\n    stride_hidden_m: tl.int64,\n    stride_hidden_k: tl.int64,\n    weight_ptr,\n    stride_weight_n: tl.int64,\n    stride_weight_k: tl.int64,\n    labels_ptr,\n    stride_labels: 
tl.int64,\n    maximum_ptr,\n    stride_maximum: tl.int64,\n    accu_ptr,\n    stride_accu: tl.int64,\n    d_entropy_ptr,\n    stride_d_entropy: tl.int64,\n    d_logprobs_ptr,\n    stride_d_logprobs: tl.int64,\n    reduction: int,\n    entropy_b_ptr,\n    stride_entropy_b,\n    d_logits_ptr,\n    stride_d_logits_m: tl.int64,\n    stride_d_logits_n: tl.int64,\n    rcp_temperature: tl.float32,\n    BLOCK_SIZE_M: tl.constexpr,\n    BLOCK_SIZE_N: tl.constexpr,\n    BLOCK_SIZE_K: tl.constexpr,\n    GROUP_SIZE_M: tl.constexpr,\n):\n    pid = tl.program_id(axis=0)\n    num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M)\n    num_pid_n = tl.cdiv(vocab_per_split, BLOCK_SIZE_N)\n    num_pid_in_group = GROUP_SIZE_M * num_pid_n\n    group_id = pid // num_pid_in_group\n    first_pid_m = group_id * GROUP_SIZE_M\n    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)\n    pid_n = (pid % num_pid_in_group) // group_size_m\n\n    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_bn = split_idx * vocab_per_split + pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n\n    maximum = tl.load(maximum_ptr + offs_am * stride_maximum, mask=offs_am < num_tokens, other=0.0)\n    accu = tl.load(accu_ptr + offs_am * stride_accu, mask=offs_am < num_tokens, other=1e-6)\n    accu_rcp = tl.fdiv(1.0, accu)\n    d_entropy = tl.load(d_entropy_ptr + offs_am * stride_d_entropy, mask=offs_am < num_tokens, other=0.0)\n    if reduction == 0:\n        d_logprobs = tl.load(d_logprobs_ptr + offs_am * stride_d_logprobs, mask=offs_am < num_tokens, other=0.0)\n    elif reduction == 1:\n        d_logprobs = tl.load(d_logprobs_ptr)\n        d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,))\n    else:\n        d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32))\n        d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,))\n    d_logprobs = -1 * d_logprobs\n    entropy_b = tl.load(entropy_b_ptr + offs_am * stride_entropy_b, mask=offs_am < num_tokens, other=0.0)\n    labels = tl.load(labels_ptr + offs_am * stride_labels, mask=offs_am < num_tokens, other=0)\n\n    hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k)\n    weight_ptrs = weight_ptr + (offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k)\n\n    vocab_right_bound = min((split_idx + 1) * vocab_per_split, vocab_size)\n    logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n    for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)):\n        _hidden = tl.load(\n            hidden_ptrs,\n            mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens),\n            other=0.0,\n        )\n        _weight = tl.load(\n            weight_ptrs,\n            mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_right_bound),\n            other=0.0,\n        )\n        logits = tl.dot(_hidden, _weight.trans(), logits)\n\n        hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k\n        weight_ptrs += BLOCK_SIZE_K * stride_weight_k\n\n    logits *= rcp_temperature\n    exp_logits = tl.exp(logits - maximum[:, None])\n\n    mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None]\n    d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask)\n    d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None])\n\n    d_logits *= 
rcp_temperature\n\n    # filter d_logits with mask\n    result_offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    mask = (offs_am[:, None] < num_tokens) & (result_offs_n[None, :] < vocab_per_split)\n\n    tl.store(\n        d_logits_ptr + offs_am[:, None] * stride_d_logits_m + result_offs_n[None, :] * stride_d_logits_n, d_logits, mask\n    )\n\n\ndef efficient_entropy_backward(\n    dlogprobs: torch.Tensor,\n    dentropy: torch.Tensor,\n    hidden: torch.Tensor,\n    weight: torch.Tensor,\n    labels: torch.Tensor,\n    maximum: torch.Tensor,\n    acc: torch.Tensor,\n    entropy_b: torch.Tensor,\n    reduction: typing.Optional[int] = 2,\n    should_return_fp32_grad: bool = False,\n    temperature: typing.Optional[float] = 1.0,\n    dist_process_group: typing.Optional[dist.ProcessGroup] = None,\n) -> list[torch.Tensor]:\n    \"\"\"\n    backward host function\n    \"\"\"\n    assert hidden.is_cuda and weight.is_cuda and labels.is_cuda\n    assert weight.device == hidden.device and labels.device == hidden.device\n    assert hidden.dim() == 2 and weight.dim() == 2 and labels.dim() == 1\n    assert hidden.is_contiguous() and weight.is_contiguous() and labels.is_contiguous()\n    assert hidden.shape[0] == labels.shape[0] and hidden.shape[1] == weight.shape[1]\n\n    _rank = 0 if dist_process_group is None else dist.get_rank(dist_process_group)\n    _world_size = 1 if dist_process_group is None else dist.get_world_size(dist_process_group)\n\n    num_tokens, hidden_size = hidden.shape\n    num_tokens = labels.shape[0]\n    vocab_size, hidden_size = weight.shape\n    assert hidden_size % 128 == 0\n\n    REDUCTION = get_entropy_reduction_enum(reduction)\n\n    if REDUCTION == EntropyReductionEnum._None:\n        assert dlogprobs.shape == (num_tokens,)\n    else:\n        assert dlogprobs.dim() == 0\n\n    assert dlogprobs.is_contiguous() and dentropy.is_contiguous()\n    assert dlogprobs.is_cuda and dentropy.is_cuda\n    assert dlogprobs.device == hidden.device and dlogprobs.device == dentropy.device\n    assert dentropy.shape == (num_tokens,)\n\n    d_hidden, d_weight = None, None\n    if _config._backward == BackwardEnum._Total_Fuse_MN or should_return_fp32_grad:\n        d_hidden = torch.zeros_like(hidden, dtype=torch.float32, device=hidden.device)\n        d_weight = torch.zeros_like(weight, dtype=torch.float32, device=weight.device)\n    else:\n        d_hidden = torch.empty_like(hidden, dtype=hidden.dtype, device=hidden.device)\n        d_weight = torch.empty_like(weight, dtype=hidden.dtype, device=weight.device)\n    assert d_hidden.is_contiguous() and d_weight.is_contiguous()\n\n    assert maximum.is_contiguous() and acc.is_contiguous()\n    assert maximum.device == hidden.device and acc.device == hidden.device\n    assert maximum.shape == labels.shape == acc.shape\n    assert maximum.is_cuda and acc.is_cuda\n\n    vocab_per_split = 1024\n    assert vocab_per_split % 128 == 0\n    num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split\n\n    assert entropy_b.is_contiguous() and entropy_b.is_cuda\n    assert entropy_b.shape == (num_tokens,)\n\n    if _config._backward == BackwardEnum._Total_Fuse_MN:\n        # --- Triton doesn't materialize d_logits at all. 
Split tiles at the perspective of d_logits.\n        def mainloop_grid(meta):\n            return (triton.cdiv(num_tokens, meta[\"BLOCK_SIZE_M\"]) * triton.cdiv(vocab_size, meta[\"BLOCK_SIZE_N\"]),)\n\n        efficient_entropy_backward_kernel_general_mainloop_MN[mainloop_grid](\n            num_tokens,\n            hidden_size,\n            vocab_size,\n            _rank,\n            hidden,\n            hidden.stride(0),\n            hidden.stride(1),\n            weight,\n            weight.stride(0),\n            weight.stride(1),\n            labels,\n            labels.stride(0),\n            maximum,\n            maximum.stride(0),\n            acc,\n            acc.stride(0),\n            dentropy,\n            dentropy.stride(0),\n            dlogprobs,\n            dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0,\n            REDUCTION,\n            entropy_b,\n            entropy_b.stride(0),\n            d_hidden,\n            d_hidden.stride(0),\n            d_hidden.stride(1),\n            d_weight,\n            d_weight.stride(0),\n            d_weight.stride(1),\n            1.0 / temperature,\n        )\n\n    elif _config._backward == BackwardEnum._Total_Separate:\n        _d_logits = torch.empty((num_tokens, vocab_size), device=hidden.device, dtype=hidden.dtype).contiguous()\n        assert _d_logits.is_contiguous()\n\n        if _config._use_triton:\n\n            def d_logits_grid(meta):\n                return (triton.cdiv(num_tokens, meta[\"BLOCK_SIZE_M\"]) * triton.cdiv(vocab_size, meta[\"BLOCK_SIZE_N\"]),)\n\n            efficient_entropy_backward_kernel_general_d_logits[d_logits_grid](\n                num_tokens,\n                hidden_size,\n                vocab_size,\n                _rank,\n                hidden,\n                hidden.stride(0),\n                hidden.stride(1),\n                weight,\n                weight.stride(0),\n                weight.stride(1),\n                labels,\n                labels.stride(0),\n                maximum,\n                maximum.stride(0),\n                acc,\n                acc.stride(0),\n                dentropy,\n                dentropy.stride(0),\n                dlogprobs,\n                dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0,\n                REDUCTION,\n                entropy_b,\n                entropy_b.stride(0),\n                _d_logits,\n                _d_logits.stride(0),\n                _d_logits.stride(1),\n                1.0 / temperature,\n            )\n\n            torch.matmul(_d_logits, weight, out=d_hidden)\n            torch.matmul(_d_logits.T, hidden, out=d_weight)\n        else:\n            raise AssertionError(\"Triton is required for efficient entropy kernel\")\n\n    elif _config._backward == BackwardEnum._Split_Dlogits_N:\n        vocab_per_split = 9504\n        num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split\n\n        _d_logits = torch.empty((num_tokens, vocab_per_split), device=hidden.device, dtype=hidden.dtype).contiguous()\n        assert _d_logits.is_contiguous()\n\n        def d_logits_grid(meta):\n            return (triton.cdiv(num_tokens, meta[\"BLOCK_SIZE_M\"]) * triton.cdiv(vocab_per_split, meta[\"BLOCK_SIZE_N\"]),)\n\n        for split_idx in range(num_splits):\n            efficient_entropy_backward_kernel_general_d_logits_split_N[d_logits_grid](\n                split_idx,\n                num_tokens,\n                hidden_size,\n                vocab_size,\n             
   vocab_per_split,\n                _rank,\n                hidden,\n                hidden.stride(0),\n                hidden.stride(1),\n                weight,\n                weight.stride(0),\n                weight.stride(1),\n                labels,\n                labels.stride(0),\n                maximum,\n                maximum.stride(0),\n                acc,\n                acc.stride(0),\n                dentropy,\n                dentropy.stride(0),\n                dlogprobs,\n                dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0,\n                REDUCTION,\n                entropy_b,\n                entropy_b.stride(0),\n                _d_logits,\n                _d_logits.stride(0),\n                _d_logits.stride(1),\n                1.0 / temperature,\n            )\n\n            if split_idx == (num_splits - 1):\n                vocab_right_bound = min((split_idx + 1) * vocab_per_split, vocab_size) - split_idx * vocab_per_split\n                _d_logits = _d_logits[:, :vocab_right_bound].contiguous()\n\n            if split_idx == 0:\n                torch.matmul(\n                    _d_logits, weight[split_idx * vocab_per_split : (split_idx + 1) * vocab_per_split, :], out=d_hidden\n                )\n            else:\n                d_hidden += torch.matmul(\n                    _d_logits, weight[split_idx * vocab_per_split : (split_idx + 1) * vocab_per_split, :]\n                )\n            torch.matmul(\n                _d_logits.T, hidden, out=d_weight[split_idx * vocab_per_split : (split_idx + 1) * vocab_per_split, :]\n            )\n\n    elif _config._backward == BackwardEnum._Split_Dlogits_M:\n        raise NotImplementedError(\"BackwardEnum._Split_Dlogits_M is not implemented yet\")\n\n    return d_hidden, d_weight\n"
  },
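  {
    "path": "examples/kernel/entropy_backward_strategies.py",
    "content": "# Sketch: selecting a backward strategy for the fused entropy kernels.\n# NOTE: illustrative file only; `_config` and `BackwardEnum` are internal knobs of the\n# kernels module (the backward host function branches on `_config._backward`), so these\n# names and their defaults may change. Shapes below are assumptions for a quick smoke test.\nimport torch\n\nfrom siirl.utils.kernel import kernels\nfrom siirl.utils.kernel.linear_cross_entropy import linear_cross_entropy\n\n\ndef run_once():\n    device = torch.device(\"cuda\")\n    hidden = torch.randn(1024, 1024, device=device, dtype=torch.bfloat16, requires_grad=True)\n    weight = torch.randn(32000, 1024, device=device, dtype=torch.bfloat16, requires_grad=True)\n    labels = torch.randint(0, 32000, (1024,), device=device)\n\n    logprobs, entropy = linear_cross_entropy(hidden, weight, labels, 1.0, \"none\")\n    (-logprobs.mean() - 0.01 * entropy.mean()).backward()\n    return hidden.grad, weight.grad\n\n\ndef main():\n    if not torch.cuda.is_available():\n        print(\"A CUDA device is required: the kernels are implemented in Triton.\")\n        return\n\n    # _Total_Separate materializes the full (num_tokens, vocab_size) d_logits once;\n    # _Split_Dlogits_N materializes one vocabulary slice of d_logits at a time;\n    # _Total_Fuse_MN (not exercised here) never materializes d_logits at all.\n    for strategy in (kernels.BackwardEnum._Total_Separate, kernels.BackwardEnum._Split_Dlogits_N):\n        kernels._config._backward = strategy\n        d_hidden, d_weight = run_once()\n        print(strategy, d_hidden.shape, d_weight.shape)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },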
  {
    "path": "siirl/utils/kernel/linear_cross_entropy.py",
    "content": "#\n# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n\n# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport typing\n\nimport torch\nimport torch.distributed as dist\n\nfrom . import kernels\n\n\nclass LinearCrossEntropy(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        hidden: torch.Tensor,\n        weight: torch.Tensor,\n        labels: torch.Tensor,\n        temperature: typing.Optional[float] = 1.0,\n        reduction: typing.Optional[str] = \"none\",\n        dist_process_group: typing.Optional[dist.ProcessGroup] = None,\n    ) -> list[torch.Tensor]:\n        \"\"\"_summary_\n\n        Args:\n            ctx (_type_): _description_\n            hidden (torch.Tensor): (batch_size, num_tokens, hidden_size) -> (batch_size * num_tokens, hidden_size)\n            weight (torch.Tensor): (vocab_size, hidden_size)\n            labels (torch.Tensor): (batch_size, num_tokens) -> (batch_size * num_tokens, )\n            temperature (typing.Optional[float], optional): _description_. Defaults to 1.0.\n            reduction (typing.Optional[str], optional): _description_. Defaults to \"none\".\n            dist_process_group (typing.Optional[dist.ProcessGroup], optional): _description_. 
Defaults to None.\n\n        Returns:\n            typing.List[torch.Tensor]: _description_\n        \"\"\"\n\n        assert isinstance(temperature, float), f\"temperature must be a float, but got {type(temperature)}\"\n        assert isinstance(reduction, str), f\"reduction must be a str, but got {type(reduction)}\"\n        with torch.cuda.nvtx.range(\"LinearCrossEntropy-forward\"):\n            REDUCTION = kernels.get_entropy_reduction_enum_number(reduction.lower())\n\n            original_hidden_shape = hidden.shape\n            if len(hidden.shape) != 2:\n                hidden = hidden.view(-1, hidden.shape[-1])  # (batch_size * num_tokens, hidden_size)\n            if len(labels.shape) != 1:\n                labels = labels.view(-1)\n\n            logprobs, entropy, _maximum, _accumulate, _entropy_b = kernels.efficient_entropy_forward(\n                hidden, weight, labels, REDUCTION, temperature, dist_process_group\n            )\n\n            ctx.save_for_backward(hidden, weight, labels, _maximum, _accumulate, _entropy_b)\n            ctx.original_hidden_shape = original_hidden_shape\n            ctx.REDUCTION = REDUCTION\n            ctx.dist_process_group = dist_process_group\n            ctx.should_return_fp32_grad = False\n            ctx.temperature = temperature\n        return logprobs, entropy\n\n    @staticmethod\n    def backward(ctx, dlogprobs: torch.Tensor, dentropy: torch.Tensor) -> list[torch.Tensor]:\n        with torch.cuda.nvtx.range(\"LinearCrossEntropy-backward\"):\n            (hidden, weight, labels, _maximum, _accumulate, _entropy_b) = ctx.saved_tensors\n            REDUCTION = ctx.REDUCTION\n            dist_process_group = ctx.dist_process_group\n            should_return_fp32_grad = ctx.should_return_fp32_grad\n            temperature = ctx.temperature\n\n            d_hidden, d_weight = kernels.efficient_entropy_backward(\n                dlogprobs,\n                dentropy,\n                hidden,\n                weight,\n                labels,\n                _maximum,\n                _accumulate,\n                _entropy_b,\n                REDUCTION,\n                should_return_fp32_grad,\n                temperature,\n                dist_process_group,\n            )\n            d_hidden = d_hidden.view(ctx.original_hidden_shape)\n\n        return (d_hidden, d_weight, None, None, None, None)\n\n\nlinear_cross_entropy = LinearCrossEntropy.apply\n"
  },
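  {
    "path": "examples/kernel/linear_cross_entropy_usage.py",
    "content": "# Minimal usage sketch for the fused linear + cross-entropy kernel.\n# NOTE: illustrative file only; the path, shapes and sizes below are assumptions\n# (single CUDA device, no tensor-parallel process group), not library requirements beyond\n# what the kernels themselves assert (CUDA tensors, hidden_size divisible by 128).\nimport torch\n\nfrom siirl.utils.kernel.linear_cross_entropy import linear_cross_entropy\n\n\ndef main():\n    if not torch.cuda.is_available():\n        print(\"A CUDA device is required: the kernels are implemented in Triton.\")\n        return\n\n    device = torch.device(\"cuda\")\n    batch_size, seq_len, hidden_size, vocab_size = 2, 512, 1024, 32000\n\n    # Last-layer activations, LM-head weight and target token ids.\n    hidden = torch.randn(batch_size, seq_len, hidden_size, device=device, dtype=torch.bfloat16, requires_grad=True)\n    weight = torch.randn(vocab_size, hidden_size, device=device, dtype=torch.bfloat16, requires_grad=True)\n    labels = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)\n\n    # Per-token log-probabilities of the labels and per-token entropies, computed\n    # without materializing the full (num_tokens, vocab_size) logits tensor.\n    logprobs, entropy = linear_cross_entropy(hidden, weight, labels, 1.0, \"none\")\n\n    # Typical RL-style objective: negative log-likelihood with a small entropy bonus.\n    loss = -logprobs.mean() - 0.01 * entropy.mean()\n    loss.backward()\n\n    print(logprobs.shape, entropy.shape, hidden.grad.shape, weight.grad.shape)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },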
  {
    "path": "siirl/utils/logger/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "siirl/utils/logger/aggregate_logger.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nA Ray logger will receive logging info from different processes.\n\"\"\"\n\nimport logging\nimport numbers\nfrom typing import Dict\nfrom loguru import logger\n\n\ndef concat_dict_to_str(dict: Dict, step):\n    output = [f\"step:{step}\"]\n    for k, v in dict.items():\n        if isinstance(v, numbers.Number):\n            output.append(f\"{k}:{v:.3f}\")\n    output_str = \" - \".join(output)\n    return output_str\n\n\nclass LocalLogger:\n    def __init__(self, remote_logger=None, enable_wandb=False, print_to_console=False):\n        self.print_to_console = print_to_console\n\n    def flush(self):\n        pass\n\n    def log(self, data, step):\n        if self.print_to_console:\n            logger.info(concat_dict_to_str(data, step=step))\n\n\nclass DecoratorLoggerBase:\n    def __init__(self, role: str, logger: logging.Logger = None, level=logging.DEBUG, rank: int = 0, log_only_rank_0: bool = True):\n        self.role = role\n        self.logger = logger\n        self.level = level\n        self.rank = rank\n        self.log_only_rank_0 = log_only_rank_0\n        self.logging_function = self.log_by_logging\n        if logger is None:\n            self.logging_function = self.log_by_print\n\n    def log_by_print(self, log_str):\n        if not self.log_only_rank_0 or self.rank == 0:\n            print(f\"{self.role} {log_str}\", flush=True)\n\n    def log_by_logging(self, log_str):\n        if self.logger is None:\n            raise ValueError(\"Logger is not initialized\")\n        if not self.log_only_rank_0 or self.rank == 0:\n            self.logger.log(self.level, f\"{self.role} {log_str}\")\n\n\ndef log_with_rank(message: str, rank, logger: logging.Logger, level=logging.INFO, log_only_rank_0: bool = False):\n    \"\"\"_summary_\n    Log a message with rank information using a logger.\n    This function logs the message only if `log_only_rank_0` is False or if the rank is 0.\n    Args:\n        message (str): The message to log.\n        rank (int): The rank of the process.\n        logger (logging.Logger): The logger instance to use for logging.\n        level (int, optional): The logging level. Defaults to logging.INFO.\n        log_only_rank_0 (bool, optional): If True, only log for rank 0. Defaults to False.\n    \"\"\"\n    if not log_only_rank_0 or rank == 0:\n        logger.log(level, f\"[Rank {rank}] {message}\")"
  },
  {
    "path": "siirl/utils/logger/logging_utils.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom loguru import logger\nimport os\nimport sys\n\nSIIRL_LOG_DIRCTORY = os.getenv(\"SIIRL_LOG_DIRECTORY\", \"siirl_logs\")\n\nSIIRL_LOGGING_FILENAME = os.getenv(\"SIIRL_LOGGING_FILENAME\", \"siirl\")\nif SIIRL_LOGGING_FILENAME != \"\":\n    SIIRL_LOGGING_FILENAME += \"_\"\n\n\ndef set_basic_config():\n    \"\"\"\n    This function sets the global logging format and level. It will be called when import siirl\n    \"\"\"\n\n    log_level = os.environ.get(\"LOGURU_LEVEL\", \"INFO\")\n\n    logger.remove()\n    logger.level(\"CRITICAL\", color=\"<bold white on red>\")\n    # logger.level(\"ERROR\", color=\"<red><bold>\")\n    # logger.level(\"WARNING\", color=\"<yellow>\")\n    logger.add(\n        sys.stderr,\n        level=log_level,\n        format=(\n            \"<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | \"\n            \"<level>{level.icon}</level> \"\n            \"<level>{level: <8}</level> | \"\n            # Added :<magenta>{line}</magenta> to include the line number\n            \"<blue>{file}</blue>:<magenta>{line}</magenta>:<cyan>{function}</cyan> >> \"\n            \"<level>{message}</level>\"\n        ),\n        enqueue=True,\n        colorize=True,\n    )\n    os.makedirs(SIIRL_LOG_DIRCTORY, exist_ok=True)\n\n    logger.add(\n        sink=os.path.join(SIIRL_LOG_DIRCTORY, SIIRL_LOGGING_FILENAME + \"{time:YYYY-MM-DD-HH}.log\"),\n        level=log_level,\n        rotation=\"500 MB\",\n        retention=\"30 days\",\n        format=\"{time:YYYY-MM-DD HH:mm:ss.SSS} | {level} | {module}:{function}:{line} >> {message}\",\n        compression=\"zip\",\n        encoding=\"utf-8\",\n        enqueue=True,\n    )\n"
  },
  {
    "path": "siirl/utils/logger/tracking.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nA unified tracking interface that supports logging data to different backend\n\"\"\"\n\nimport dataclasses\nfrom enum import Enum\nfrom functools import partial\nfrom pathlib import Path\nfrom typing import List, Union, Dict, Any\nfrom loguru import logger\nimport os\n\n\nclass Tracking:\n    \"\"\"A unified tracking interface for logging experiment data to multiple backends.\n\n    This class provides a centralized way to log experiment metrics, parameters, and artifacts\n    to various tracking backends including WandB, MLflow, SwanLab, TensorBoard, and console.\n\n    Attributes:\n        supported_backend: List of supported tracking backends.\n        logger: Dictionary of initialized logger instances for each backend.\n    \"\"\"\n\n    supported_backend = [\"wandb\", \"mlflow\", \"swanlab\", \"vemlp_wandb\", \"tensorboard\", \"console\", \"clearml\"]\n\n    def __init__(self, project_name, experiment_name, default_backend: Union[str, List[str]] = \"console\", config=None):\n        if isinstance(default_backend, str):\n            default_backend = [default_backend]\n        for backend in default_backend:\n            if backend == \"tracking\":\n                import warnings\n\n                warnings.warn(\"`tracking` logger is deprecated. 
use `wandb` instead.\", DeprecationWarning, stacklevel=2)\n            else:\n                assert backend in self.supported_backend, f\"{backend} is not supported\"\n\n        self.logger = {}\n\n        if \"tracking\" in default_backend or \"wandb\" in default_backend:\n            import wandb\n\n            settings = None\n            if config[\"trainer\"].get(\"wandb_proxy\", None):\n                settings = wandb.Settings(https_proxy=config[\"trainer\"][\"wandb_proxy\"])\n            wandb.init(project=project_name, name=experiment_name, config=config, settings=settings)\n            self.logger[\"wandb\"] = wandb\n\n        if \"mlflow\" in default_backend:\n            import os\n\n            import mlflow\n\n            MLFLOW_TRACKING_URI = os.environ.get(\"MLFLOW_TRACKING_URI\", None)\n            if MLFLOW_TRACKING_URI:\n                mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)\n\n            # Project_name is actually experiment_name in MLFlow\n            # If experiment does not exist, will create a new experiment\n            experiment = mlflow.set_experiment(project_name)\n            mlflow.start_run(experiment_id=experiment.experiment_id, run_name=experiment_name)\n            mlflow.log_params(_compute_mlflow_params_from_objects(config))\n            self.logger[\"mlflow\"] = _MlflowLoggingAdapter()\n\n        if \"swanlab\" in default_backend:\n            import os\n\n            import swanlab\n\n            SWANLAB_API_KEY = os.environ.get(\"SWANLAB_API_KEY\", None)\n            SWANLAB_LOG_DIR = os.environ.get(\"SWANLAB_LOG_DIR\", \"swanlog\")\n            SWANLAB_MODE = os.environ.get(\"SWANLAB_MODE\", \"cloud\")\n            if SWANLAB_API_KEY:\n                swanlab.login(SWANLAB_API_KEY)  # NOTE: previous login information will be overwritten\n\n            if config is None:\n                config = {}  # make sure config is not None, otherwise **config will raise error\n            swanlab.init(\n                project=project_name,\n                experiment_name=experiment_name,\n                config={\"FRAMEWORK\": \"siirl\", **config},\n                logdir=SWANLAB_LOG_DIR,\n                mode=SWANLAB_MODE,\n            )\n            self.logger[\"swanlab\"] = swanlab\n\n        if \"vemlp_wandb\" in default_backend:\n            import os\n\n            import volcengine_ml_platform\n            from volcengine_ml_platform import wandb as vemlp_wandb\n\n            volcengine_ml_platform.init(\n                ak=os.environ[\"VOLC_ACCESS_KEY_ID\"],\n                sk=os.environ[\"VOLC_SECRET_ACCESS_KEY\"],\n                region=os.environ[\"MLP_TRACKING_REGION\"],\n            )\n\n            vemlp_wandb.init(\n                project=project_name,\n                name=experiment_name,\n                config=config,\n                sync_tensorboard=True,\n            )\n            self.logger[\"vemlp_wandb\"] = vemlp_wandb\n\n        if \"tensorboard\" in default_backend:\n            self.logger[\"tensorboard\"] = _TensorboardAdapter()\n\n        if \"console\" in default_backend:\n            from siirl.utils.logger.aggregate_logger import LocalLogger\n\n            self.console_logger = LocalLogger(print_to_console=True)\n            self.logger[\"console\"] = self.console_logger\n\n        if \"clearml\" in default_backend:\n            self.logger[\"clearml\"] = ClearMLLogger(project_name, experiment_name, config)\n\n    def log(self, data, step, backend=None):\n        for default_backend, logger_instance in 
self.logger.items():\n            if backend is None or default_backend in backend:\n                logger_instance.log(data=data, step=step)\n\n    def __del__(self):\n        if \"wandb\" in self.logger:\n            self.logger[\"wandb\"].finish(exit_code=0)\n        if \"swanlab\" in self.logger:\n            self.logger[\"swanlab\"].finish()\n        if \"vemlp_wandb\" in self.logger:\n            self.logger[\"vemlp_wandb\"].finish(exit_code=0)\n        if \"tensorboard\" in self.logger:\n            self.logger[\"tensorboard\"].finish()\n\n        if \"clearml\" in self.logger:\n            self.logger[\"clearml\"].finish()\n\n\nclass ClearMLLogger:\n    def __init__(self, project_name: str, experiment_name: str, config):\n        self.project_name = project_name\n        self.experiment_name = experiment_name\n\n        import clearml\n\n        self._task: clearml.Task = clearml.Task.init(\n            task_name=experiment_name,\n            project_name=project_name,\n            continue_last_task=True,\n            output_uri=False,\n        )\n\n        self._task.connect_configuration(config, name=\"Hyperparameters\")\n\n    def _get_logger(self):\n        return self._task.get_logger()\n\n    def log(self, data, step):\n        import numpy as np\n        import pandas as pd\n\n        # logs = self._rewrite_logs(data)\n        # bind the ClearML logger to a distinct name so the module-level loguru logger is not shadowed\n        clearml_logger = self._get_logger()\n        for k, v in data.items():\n            title, series = k.split(\"/\", 1)\n\n            if isinstance(v, (int, float, np.floating, np.integer)):\n                clearml_logger.report_scalar(\n                    title=title,\n                    series=series,\n                    value=v,\n                    iteration=step,\n                )\n            elif isinstance(v, pd.DataFrame):\n                clearml_logger.report_table(\n                    title=title,\n                    series=series,\n                    table_plot=v,\n                    iteration=step,\n                )\n            else:\n                logger.warning(f'Trainer is attempting to log a value of \"{v}\" of type {type(v)} for key \"{k}\". This invocation of ClearML logger\\'s function is incorrect so this attribute was dropped. 
')\n\n    def finish(self):\n        self._task.mark_completed()\n\n\nclass _TensorboardAdapter:\n    def __init__(self):\n        import os\n\n        from torch.utils.tensorboard import SummaryWriter\n\n        tensorboard_dir = os.environ.get(\"TENSORBOARD_DIR\", \"tensorboard_log\")\n        os.makedirs(tensorboard_dir, exist_ok=True)\n        logger.info(f\"Saving tensorboard log to {tensorboard_dir}.\")\n        self.writer = SummaryWriter(tensorboard_dir)\n\n    def log(self, data, step):\n        for key in data:\n            self.writer.add_scalar(key, data[key], step)\n\n    def finish(self):\n        self.writer.close()\n\n\nclass _MlflowLoggingAdapter:\n    def log(self, data, step):\n        import mlflow\n\n        results = {k.replace(\"@\", \"_at_\"): v for k, v in data.items()}\n        mlflow.log_metrics(metrics=results, step=step)\n\n\ndef _compute_mlflow_params_from_objects(params) -> Dict[str, Any]:\n    if params is None:\n        return {}\n\n    return _flatten_dict(_transform_params_to_json_serializable(params, convert_list_to_dict=True), sep=\"/\")\n\n\ndef _transform_params_to_json_serializable(x, convert_list_to_dict: bool):\n    _transform = partial(_transform_params_to_json_serializable, convert_list_to_dict=convert_list_to_dict)\n\n    if dataclasses.is_dataclass(x):\n        return _transform(dataclasses.asdict(x))\n    if isinstance(x, dict):\n        return {k: _transform(v) for k, v in x.items()}\n    if isinstance(x, list):\n        if convert_list_to_dict:\n            return {\"list_len\": len(x)} | {f\"{i}\": _transform(v) for i, v in enumerate(x)}\n        else:\n            return [_transform(v) for v in x]\n    if isinstance(x, Path):\n        return str(x)\n    if isinstance(x, Enum):\n        return x.value\n\n    return x\n\n\ndef _flatten_dict(raw: Dict[str, Any], *, sep: str) -> Dict[str, Any]:\n    import pandas as pd\n\n    ans = pd.json_normalize(raw, sep=sep).to_dict(orient=\"records\")[0]\n    assert isinstance(ans, dict)\n    return ans\n\n\n@dataclasses.dataclass\nclass ValidationGenerationsLogger:\n    def log(self, loggers, samples, step):\n        if \"wandb\" in loggers:\n            self.log_generations_to_wandb(samples, step)\n        if \"swanlab\" in loggers:\n            self.log_generations_to_swanlab(samples, step)\n        if \"mlflow\" in loggers:\n            self.log_generations_to_mlflow(samples, step)\n\n        if \"clearml\" in loggers:\n            self.log_generations_to_clearml(samples, step)\n        if \"tensorboard\" in loggers:\n            self.log_generations_to_tensorboard(samples, step)\n\n    def log_generations_to_wandb(self, samples, step):\n        \"\"\"Log samples to wandb as a table\"\"\"\n        import wandb\n\n        # Create column names for all samples\n        columns = [\"step\"] + sum([[f\"input_{i + 1}\", f\"output_{i + 1}\", f\"score_{i + 1}\"] for i in range(len(samples))], [])\n\n        if not hasattr(self, \"validation_table\"):\n            # Initialize the table on first call\n            self.validation_table = wandb.Table(columns=columns)\n\n        # Create a new table with same columns and existing data\n        # Workaround for https://github.com/wandb/wandb/issues/2981#issuecomment-1997445737\n        new_table = wandb.Table(columns=columns, data=self.validation_table.data)\n\n        # Add new row with all data\n        row_data = []\n        row_data.append(step)\n        for sample in samples:\n            row_data.extend(sample)\n\n        new_table.add_data(*row_data)\n\n  
      # Update reference and log\n        wandb.log({\"val/generations\": new_table}, step=step)\n        self.validation_table = new_table\n\n    def log_generations_to_swanlab(self, samples, step):\n        \"\"\"Log samples to swanlab as text\"\"\"\n        import swanlab\n\n        swanlab_text_list = []\n        for i, sample in enumerate(samples):\n            row_text = f\"\"\"\n            input: {sample[0]}\n            \n            ---\n            \n            output: {sample[1]}\n            \n            ---\n            \n            score: {sample[2]}\n            \"\"\"\n            swanlab_text_list.append(swanlab.Text(row_text, caption=f\"sample {i + 1}\"))\n\n        # Log to swanlab\n        swanlab.log({\"val/generations\": swanlab_text_list}, step=step)\n\n    def log_generations_to_mlflow(self, samples, step):\n        \"\"\"Log validation generation to mlflow as artifacts\"\"\"\n        # https://mlflow.org/docs/latest/api_reference/python_api/mlflow.html?highlight=log_artifact#mlflow.log_artifact\n\n        import json\n        import tempfile\n\n        import mlflow\n\n        try:\n            with tempfile.TemporaryDirectory() as tmp_dir:\n                validation_gen_step_file = Path(tmp_dir, f\"val_step{step}.json\")\n                row_data = []\n                for sample in samples:\n                    data = {\"input\": sample[0], \"output\": sample[1], \"score\": sample[2]}\n                    row_data.append(data)\n                with open(validation_gen_step_file, \"w\") as file:\n                    json.dump(row_data, file)\n                mlflow.log_artifact(validation_gen_step_file)\n        except Exception as e:\n            logger.warning(f\"save validation generation file to mlflow failed with error {e}\")\n\n    def log_generations_to_clearml(self, samples, step):\n        \"\"\"Log validation generation to clearml as table\"\"\"\n\n        import clearml\n        import pandas as pd\n\n        task: clearml.Task | None = clearml.Task.current_task()\n        if task is None:\n            return\n\n        table = [\n            {\n                \"step\": step,\n                \"input\": sample[0],\n                \"output\": sample[1],\n                \"score\": sample[2],\n            }\n            for sample in samples\n        ]\n\n        logger = task.get_logger()\n        logger.report_table(\n            series=\"Validation generations\",\n            title=\"Validation\",\n            table_plot=pd.DataFrame.from_records(table),\n            iteration=step,\n        )\n\n    def log_generations_to_tensorboard(self, samples, step):\n        \"\"\"Log samples to tensorboard as text\"\"\"\n        # Initialize tensorboard writer if not exists\n        if not hasattr(self, \"writer\"):\n            from torch.utils.tensorboard import SummaryWriter\n\n            tensorboard_dir = os.environ.get(\"TENSORBOARD_DIR\", \"tensorboard_log\")\n            os.makedirs(tensorboard_dir, exist_ok=True)\n            self.writer = SummaryWriter(log_dir=tensorboard_dir)\n\n        # Format the samples data into readable text\n        text_content = f\"**Generation Results - Step {step}**\\n\\n\"\n\n        for i, sample in enumerate(samples):\n            text_content += f\"### Sample {i + 1}\\n\"\n\n            # Assuming sample contains [input, output, score]\n            if len(sample) >= 3:\n                input_text, output_text, score = sample[0], sample[1], sample[2]\n\n                text_content += f\"**Input:** 
{input_text}\\n\\n\"\n                text_content += f\"**Output:** {output_text}\\n\\n\"\n                text_content += f\"**Score:** {score}\\n\\n\"\n            else:\n                # Handle cases where sample format might be different\n                text_content += f\"**Data:** {sample}\\n\\n\"\n\n            text_content += \"---\\n\\n\"\n\n        # Log to tensorboard as text\n        self.writer.add_text(\"val/generations\", text_content, step)\n        # Flush to ensure data is written\n        self.writer.flush()\n"
  },
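For reference, the wandb generation table built in `log_generations_to_wandb` flattens every validation sample into a single row per step. The following standalone sketch (plain Python, no wandb; the toy `samples` values are made up) reproduces that column/row layout:

```python
# Standalone sketch of the table layout used by log_generations_to_wandb.
# `samples` is assumed to be a list of (input, output, score) triples,
# matching what ValidationGenerationsLogger.log receives.
samples = [
    ("2 + 2 = ?", "4", 1.0),
    ("Capital of France?", "Paris", 1.0),
]
step = 10

# One column block per sample (input_i, output_i, score_i), prefixed by "step".
columns = ["step"] + sum(
    [[f"input_{i + 1}", f"output_{i + 1}", f"score_{i + 1}"] for i in range(len(samples))], []
)

# A single row holds the step followed by every sample flattened in order.
row_data = [step]
for sample in samples:
    row_data.extend(sample)

assert len(row_data) == len(columns)
print(list(zip(columns, row_data)))
```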
  {
    "path": "siirl/utils/megatron/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "siirl/utils/megatron/dist_checkpointing.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport numpy as np\n\n# Monkey patch np.product to np.prod for compatibility with newer numpy versions\nif not hasattr(np, \"product\"):\n    np.product = np.prod\n\nfrom megatron.core import dist_checkpointing, mpu\nfrom megatron.core.dist_checkpointing.serialization import (\n    get_default_load_sharded_strategy,\n    get_default_save_sharded_strategy,\n)\nfrom megatron.core.dist_checkpointing.strategies.fully_parallel import (\n    FullyParallelLoadStrategyWrapper,\n    FullyParallelSaveStrategyWrapper,\n)\n\n\ndef save_dist_checkpointing(sharded_state_dict, ckpt_path, async_save=False):\n    validate_sharding_integrity = True\n    # Get checkpointing strategies\n    save_strategy = get_default_save_sharded_strategy(\"torch_dist\")\n    save_strategy = FullyParallelSaveStrategyWrapper(\n        save_strategy, mpu.get_data_parallel_group(with_context_parallel=True)\n    )\n\n    # Save model sharded state dicts\n    async_save_request = dist_checkpointing.save(\n        sharded_state_dict,\n        ckpt_path,\n        sharded_strategy=save_strategy,\n        async_sharded_save=async_save,\n        validate_access_integrity=validate_sharding_integrity,\n    )\n\n    return async_save_request\n\n\ndef load_dist_checkpointing(sharded_state_dict, ckpt_dir):\n    # Get checkpointing strategies\n    load_strategy = get_default_load_sharded_strategy(ckpt_dir)\n    load_strategy = FullyParallelLoadStrategyWrapper(\n        load_strategy, mpu.get_data_parallel_group(with_context_parallel=True)\n    )\n\n    # Load model sharded state dicts\n    state_dict = dist_checkpointing.load(sharded_state_dict, ckpt_dir, sharded_strategy=load_strategy)\n\n    return state_dict\n"
  },
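The `np.product` shim at the top of `dist_checkpointing.py` exists because newer NumPy releases drop the `product` alias, presumably while some of the Megatron code imported below still references it. A minimal standalone check of the same guard pattern (requires only numpy):

```python
import numpy as np

# Same guard as in dist_checkpointing.py: alias np.product to np.prod when the
# installed NumPy no longer provides the removed spelling.
if not hasattr(np, "product"):
    np.product = np.prod

# Both spellings now resolve to the same reduction.
assert np.product([2, 3, 4]) == np.prod([2, 3, 4]) == 24
```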
  {
    "path": "siirl/utils/megatron/megatron_utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.\n# Copyright 2023-2024 SGLang Team\n# Copyright 2025 ModelBest Inc. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Pretrain utilities.\"\"\"\n\nimport gc\nimport os\nimport warnings\nfrom typing import Any, Dict\nimport inspect\nfrom dataclasses import dataclass\n\nimport torch\nimport torch.nn.functional as F\nfrom megatron.core import ModelParallelConfig, mpu, tensor_parallel\nfrom megatron.core.distributed import DistributedDataParallel as DDP\nfrom megatron.core.distributed import DistributedDataParallelConfig\nfrom megatron.core.enums import ModelType\nfrom megatron.core.optimizer import ChainedOptimizer, OptimizerConfig\nfrom megatron.core.transformer import TransformerConfig\nfrom megatron.core.transformer.module import Float16Module\nfrom megatron.core.utils import get_attr_wrapped_model\nfrom transformers import PretrainedConfig\n\nimport siirl.utils.megatron.tensor_parallel as tp_utils\nfrom siirl.utils.extras.device import get_device_id, get_device_name, get_torch_device\nfrom siirl.utils.model_utils.model import normalize_model_name\nfrom siirl.utils.model_utils.torch_dtypes import PrecisionType\nfrom siirl.utils.extras.fs import local_mkdir_safe\n\n\ndef get_model_config(model):\n    return get_attr_wrapped_model(model, \"config\", allow_none=False)\n\n\ndef get_model(\n    model_provider_func,\n    model_type=ModelType.encoder_or_decoder,\n    wrap_with_ddp=True,\n    use_distributed_optimizer=True,\n    transformer_config=None,\n    override_ddp_config=None,\n):\n    \"\"\"Build the model.\"\"\"\n    # Build model.\n    if (\n        mpu.get_pipeline_model_parallel_world_size() > 1\n        and mpu.get_virtual_pipeline_model_parallel_world_size() is not None\n    ):\n        assert model_type != ModelType.encoder_and_decoder, (\n            \"Interleaved schedule not supported for model with both encoder and decoder\"\n        )\n        model = []\n        has_vp_stage = inspect.signature(mpu.is_pipeline_first_stage).parameters.get(\"vp_stage\", None) is not None\n        for i in range(mpu.get_virtual_pipeline_model_parallel_world_size()):\n            mpu.set_virtual_pipeline_model_parallel_rank(i)\n            # Set pre_process and post_process only after virtual rank is set.\n            extra_kwargs = {} if not has_vp_stage else {\"ignore_virtual\": False, \"vp_stage\": i}\n            pre_process = mpu.is_pipeline_first_stage(**extra_kwargs)\n            post_process = mpu.is_pipeline_last_stage(**extra_kwargs)\n            this_model = model_provider_func(pre_process=pre_process, post_process=post_process, vp_stage=i)\n            this_model.model_type = model_type\n            model.append(this_model)\n        mpu.set_virtual_pipeline_model_parallel_rank(0)\n    else:\n        pre_process = mpu.is_pipeline_first_stage()\n        post_process = mpu.is_pipeline_last_stage()\n        add_encoder = True\n 
       add_decoder = True\n        if model_type == ModelType.encoder_and_decoder:\n            if mpu.get_pipeline_model_parallel_world_size() > 1:\n                assert mpu.get_pipeline_model_parallel_split_rank() is not None, (\n                    \"Split rank needs to be specified for model with both encoder and decoder\"\n                )\n                rank = mpu.get_pipeline_model_parallel_rank()\n                split_rank = mpu.get_pipeline_model_parallel_split_rank()\n                world_size = mpu.get_pipeline_model_parallel_world_size()\n                pre_process = rank == 0 or rank == split_rank\n                post_process = (rank == (split_rank - 1)) or (rank == (world_size - 1))\n                add_encoder = mpu.is_pipeline_stage_before_split()\n                add_decoder = mpu.is_pipeline_stage_after_split()\n            model = model_provider_func(\n                pre_process=pre_process, post_process=post_process, add_encoder=add_encoder, add_decoder=add_decoder\n            )\n        else:\n            model = model_provider_func(pre_process=pre_process, post_process=post_process)\n        model.model_type = model_type\n\n    if not isinstance(model, list):\n        model = [model]\n\n    # Set tensor model parallel attributes if not set.\n    # Only parameters that are already tensor model parallel have these\n    # attributes set for them. We should make sure the default attributes\n    # are set for all params so the optimizer can use them.\n    for model_module in model:\n        for param in model_module.parameters():\n            tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param)\n\n    # Print number of parameters.\n    if mpu.get_data_parallel_rank() == 0:\n        print(\n            \" > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}\".format(\n                mpu.get_tensor_model_parallel_rank(),\n                mpu.get_pipeline_model_parallel_rank(),\n                sum([sum([p.nelement() for p in model_module.parameters()]) for model_module in model]),\n            ),\n            flush=True,\n        )\n\n    # GPU allocation.\n    if transformer_config is None or (not transformer_config.use_cpu_initialization):\n        for model_module in model:\n            model_module.to(f\"{get_device_name()}:{get_device_id()}\")\n\n    # Fp16 conversion.\n    config: TransformerConfig = get_model_config(model[0])\n    config.fp8 = None\n    tfconfig: TransformerConfig = model[0].config\n    if config.fp16 or config.bf16:  # the ModelParallelConfig in GPTModel\n        model = [Float16Module(config, model_module) for model_module in model]\n\n    if wrap_with_ddp:\n        ddp_models = []\n        ddp_config_dict = {\n            \"use_distributed_optimizer\": use_distributed_optimizer,\n            \"grad_reduce_in_fp32\": True,\n            \"overlap_grad_reduce\": False,\n        }\n        if override_ddp_config is not None:\n            ddp_config_dict.update(override_ddp_config)\n        ddp_config = DistributedDataParallelConfig(**ddp_config_dict)\n        for model_chunk_idx, model_chunk in enumerate(model):\n            ddp_model = DDP(\n                config=tfconfig,\n                module=model_chunk,\n                disable_bucketing=(model_chunk_idx > 0),\n                ddp_config=ddp_config,\n            )\n            ddp_models.append(ddp_model)\n        model = ddp_models\n        # # Broadcast params from data parallel src rank to other data parallel ranks.\n        # # 
if args.data_parallel_random_init:\n        for model_module in model:\n            model_module.broadcast_params()\n    return model\n\n@dataclass\nclass McoreModuleWrapperConfig:\n    \"\"\"Configuration for Mcore module wrapper.\"\"\"\n\n    is_value_model: bool = False\n    share_embeddings_and_output_weights: bool = False\n    wrap_with_ddp: bool = True\n    use_distributed_optimizer: bool = True\n\n\ndef make_megatron_module(\n    wrap_config: McoreModuleWrapperConfig,\n    tf_config: TransformerConfig,\n    hf_config: PretrainedConfig,\n    bridge: Any = None,\n    override_model_config: dict[str, Any] = None,\n    override_ddp_config: dict[str, Any] = None,\n):\n    if override_model_config is None:\n        override_model_config = {}\n\n    if bridge is not None:\n        from siirl.models.mcore.mbridge import freeze_moe_router, make_value_model\n\n        post_model_creation_callbacks = []\n        if wrap_config.is_value_model:\n            post_model_creation_callbacks.append(make_value_model)\n        if override_model_config.get(\"moe_config\", {}).get(\"freeze_moe_router\", False):\n            post_model_creation_callbacks.append(freeze_moe_router)\n        return bridge.get_model(\n            post_model_creation_callbacks=post_model_creation_callbacks,\n            wrap_with_ddp=wrap_config.wrap_with_ddp,\n        )\n    else:\n\n        def megatron_model_provider(pre_process, post_process, vp_stage=None):\n            from siirl.models.mcore import init_mcore_model\n\n            parallel_model = init_mcore_model(\n                tf_config,\n                hf_config,\n                pre_process,\n                post_process,\n                share_embeddings_and_output_weights=wrap_config.share_embeddings_and_output_weights,\n                value=wrap_config.is_value_model,\n                freeze_moe_router=override_model_config.get(\"moe_config\", {}).get(\"freeze_moe_router\", False),\n                vp_stage=vp_stage,\n            )\n            parallel_model.to(get_device_name())\n            return parallel_model\n\n        return get_model(\n            megatron_model_provider,\n            wrap_with_ddp=wrap_config.wrap_with_ddp,\n            use_distributed_optimizer=wrap_config.use_distributed_optimizer,\n            override_ddp_config=override_ddp_config,\n        )\n\nALL_MODULE_WRAPPER_CLASSNAMES = (DDP, Float16Module)\n\n\ndef unwrap_model(model, module_instances=ALL_MODULE_WRAPPER_CLASSNAMES):\n    return_list = True\n    if not isinstance(model, list):\n        model = [model]\n        return_list = False\n    unwrapped_model = []\n    for model_module in model:\n        while isinstance(model_module, module_instances):\n            model_module = model_module.module\n        unwrapped_model.append(model_module)\n    if not return_list:\n        return unwrapped_model[0]\n    return unwrapped_model\n\n\ndef convert_config(hf_config: PretrainedConfig, megatron_config) -> TransformerConfig:\n    print(f\"megatron config {megatron_config}\")\n    dt = PrecisionType.to_dtype(megatron_config.params_dtype)\n    print(f\"pipeline_dtype=megatron_config {dt}\")\n    qkv_bias = True if \"Qwen2ForCausalLM\" in hf_config.architectures else getattr(hf_config, \"attention_bias\", False)\n    overlap_p2p_comm = mpu.get_virtual_pipeline_model_parallel_world_size() is not None and mpu.get_virtual_pipeline_model_parallel_world_size() > 1\n    batch_p2p_comm = False\n    transformer_config = TransformerConfig(\n        num_layers=hf_config.num_hidden_layers,\n      
  hidden_size=hf_config.hidden_size,\n        num_attention_heads=hf_config.num_attention_heads,\n        num_query_groups=hf_config.num_key_value_heads,\n        ffn_hidden_size=hf_config.intermediate_size,\n        #    max_position_embeddings=hf_config.max_position_embeddings,\n        activation_func=F.silu,\n        normalization=\"RMSNorm\",\n        #    rotary_percent=False, # default,\n        gated_linear_unit=True,  # for llama\n        use_cpu_initialization=True,\n        apply_residual_connection_post_layernorm=False,  # check what's this mean\n        add_bias_linear=False,\n        tensor_model_parallel_size=mpu.get_tensor_model_parallel_world_size(),\n        pipeline_model_parallel_size=mpu.get_pipeline_model_parallel_world_size(),\n        virtual_pipeline_model_parallel_size=mpu.get_virtual_pipeline_model_parallel_world_size(),\n        context_parallel_size=mpu.get_context_parallel_world_size(),\n        overlap_p2p_comm=overlap_p2p_comm,\n        batch_p2p_comm=batch_p2p_comm,\n        pipeline_dtype=dt,\n        params_dtype=dt,\n        sequence_parallel=mpu.get_tensor_model_parallel_world_size() > 1,\n        variable_seq_lengths=True,\n        masked_softmax_fusion=True,\n        moe_token_dispatcher_type=\"alltoall\",\n        attention_dropout=hf_config.attention_dropout,\n        hidden_dropout=getattr(hf_config, \"hidden_dropout\", 0.0),\n        add_qkv_bias=qkv_bias,\n        bf16=dt is torch.bfloat16,\n    )\n\n    return transformer_config\n\n\ndef init_megatron_optim_config(optim_config: Dict) -> OptimizerConfig:\n    optim_args = {\n        \"optimizer\": \"adam\",\n        \"lr\": optim_config.lr,\n        \"min_lr\": optim_config.min_lr,\n        \"clip_grad\": optim_config.clip_grad,\n        \"weight_decay\": optim_config.weight_decay,\n        \"bf16\": True,\n        \"params_dtype\": torch.bfloat16,\n        \"use_distributed_optimizer\": True,\n    }\n\n    override_config = optim_config.override_optimizer_config\n    if override_config:\n        for k, v in override_config.items():\n            optim_args[k] = v\n\n    print_rank_0(f\"optimizer config after override: {optim_args}\")\n\n    config = OptimizerConfig(**optim_args)\n    return config\n\n\ndef mcore_model_parallel_config(\n    sequence_parallel: bool,\n    params_dtype: torch.dtype,\n) -> ModelParallelConfig:\n    # WARNING: Code should not reach this point. This function is deprecated and will be removed.\n    # Please use hf_to_mcore_config_dense() from siirl.models.mcore.config_converter instead.\n    warnings.warn(\n        \"Code should not reach this point. This function is deprecated and will be removed. 
Please use hf_to_mcore_config_dense() from siirl.models.mcore.config_converter instead.\",\n        DeprecationWarning,\n        stacklevel=2,\n    )\n    return ModelParallelConfig(\n        tensor_model_parallel_size=mpu.get_tensor_model_parallel_world_size(),\n        pipeline_model_parallel_size=mpu.get_pipeline_model_parallel_world_size(),\n        virtual_pipeline_model_parallel_size=mpu.get_virtual_pipeline_model_parallel_world_size(),\n        context_parallel_size=mpu.get_context_parallel_world_size(),\n        sequence_parallel=sequence_parallel,\n        params_dtype=params_dtype,\n        pipeline_dtype=params_dtype,\n        bf16=True,\n        fp16=False,\n        timers=None,\n    )\n\n\n@torch.no_grad()\ndef offload_megatron_model_to_cpu(models):\n    \"\"\"\n    In megatron, the model and optimizer storage are:\n    - bf16 parameter data chunked in model parallel group\n    - fp32 grad chunked in model parallel group\n    - fp32 main_parameter chunked in model and dp group\n    - fp32 optimizer state chunked in model and dp group\n    \"\"\"\n    for model_chunk in models:\n        if isinstance(model_chunk, DDP):\n            model_chunk_all_buffers = [model_chunk.buffers, model_chunk.expert_parallel_buffers]\n            for buffers in model_chunk_all_buffers:\n                for buffer in buffers:\n                    # offload parameters\n                    if buffer.param_data.storage().size() > 0:\n                        buffer.param_data.cpu_data = buffer.param_data.data.cpu().pin_memory()\n                        buffer.param_data_size = buffer.param_data.storage().size()\n                        buffer.param_data.storage().resize_(0)\n\n                    assert buffer.param_data_size == buffer.param_data.cpu_data.storage().size()\n\n                    if buffer.grad_data.storage().size() > 0:\n                        # if the grad_data size is already zero, we assume that it is already offloaded\n                        buffer.grad_data_size = buffer.grad_data.storage().size()\n                        buffer.grad_data.storage().resize_(0)\n        else:\n            # we need this for ref module\n            for _, param in model_chunk.named_parameters():\n                param.data = param.data.to(\"cpu\", non_blocking=True)\n                if param.grad is not None:\n                    param.grad = param.grad.to(\"cpu\", non_blocking=True)\n    gc.collect()\n    get_torch_device().empty_cache()\n\n\n@torch.no_grad()\ndef load_megatron_model_to_gpu(models, load_grad=True):\n    for model_chunk in models:\n        if isinstance(model_chunk, DDP):\n            model_chunk_all_buffers = [model_chunk.buffers, model_chunk.expert_parallel_buffers]\n            for buffers in model_chunk_all_buffers:\n                for buffer in buffers:\n                    # sometimes, we don't want to load grad for pure inference\n                    if load_grad:\n                        buffer.grad_data.storage().resize_(buffer.grad_data_size)\n                        buffer.grad_data.zero_()\n\n                    if buffer.param_data.storage().size() == 0:\n                        buffer.param_data.storage().resize_(buffer.param_data_size)\n                        # copy data from cpu to cuda\n                        buffer.param_data.copy_(buffer.param_data.cpu_data, non_blocking=True)\n                        # gc might free the cpu_data, but we manually do this for sure to avoid potential memory leak\n                        del buffer.param_data.cpu_data\n        
else:\n            # we need this for ref module\n            device_id = get_device_id()\n            for _, param in model_chunk.named_parameters():\n                param.data = param.data.to(device_id, non_blocking=True)\n                if param.grad is not None:\n                    param.grad = param.grad.to(device_id, non_blocking=True)\n    gc.collect()\n    get_torch_device().empty_cache()\n\n\n@torch.no_grad()\ndef offload_megatron_copy_params(optimizers):\n    \"\"\"\n    Offload optimizer parameters to CPU. Supports both Megatron optimizers\n    and `ChainedOptimizer`, which wraps a list of underlying optimizers.\n\n    Args:\n        optimizers: The optimizer or ChainedOptimizer instance.\n    \"\"\"\n\n    def _iter_opts(opt):\n        if isinstance(opt, ChainedOptimizer):\n            return opt.chained_optimizers\n        return [opt]\n\n    def offload_tensor_to_cpu(tensor):\n        if tensor is None:\n            return\n        tensor.data = tensor.data.to(\"cpu\", non_blocking=True)\n\n    def offload_group_to_cpu(group):\n        if group is None:\n            return\n\n        if isinstance(group, list):\n            for param_group in group:\n                if isinstance(param_group, list):\n                    for param in param_group:\n                        offload_tensor_to_cpu(param)\n                else:\n                    offload_tensor_to_cpu(param_group)\n        else:\n            offload_tensor_to_cpu(group)\n\n    # Offload all parameter groups to CPU for each underlying optimizer\n\n    for _opt in _iter_opts(optimizers):\n        if hasattr(_opt, \"shard_fp32_from_float16_groups\"):\n            offload_group_to_cpu(_opt.shard_fp32_from_float16_groups)\n\n\n@torch.no_grad()\ndef load_megatron_copy_params(optimizers):\n    \"\"\"\n    Load optimizer parameters back to GPU. 
Handles ChainedOptimizer.\n\n    Args:\n        optimizers: Optimizer or ChainedOptimizer instance.\n    \"\"\"\n\n    def _iter_opts(opt):\n        if isinstance(opt, ChainedOptimizer):\n            return opt.chained_optimizers\n        return [opt]\n\n    def load_tensor_to_gpu(tensor):\n        if tensor is None:\n            return\n        device_id = get_device_id()\n        tensor.data = tensor.data.to(device_id, non_blocking=True)\n\n    def load_group_to_gpu(group):\n        if group is None:\n            return\n\n        if isinstance(group, list):\n            for param_group in group:\n                if isinstance(param_group, list):\n                    for param in param_group:\n                        load_tensor_to_gpu(param)\n                else:\n                    load_tensor_to_gpu(param_group)\n        else:\n            load_tensor_to_gpu(group)\n\n    # Load all parameter groups to GPU for each underlying optimizer\n\n    for _opt in _iter_opts(optimizers):\n        if hasattr(_opt, \"shard_fp32_from_float16_groups\"):\n            load_group_to_gpu(_opt.shard_fp32_from_float16_groups)\n\n\n@torch.no_grad()\ndef offload_megatron_optimizer(optimizers):\n    def _iter_opts(opt):\n        if isinstance(opt, ChainedOptimizer):\n            return opt.chained_optimizers\n        return [opt]\n\n    for _opt in _iter_opts(optimizers):\n        offload_megatron_copy_params(_opt)\n        opt_state_dict_values = _opt.optimizer.state.values()\n        for v in opt_state_dict_values:\n            if \"exp_avg\" in v:\n                v[\"exp_avg\"] = v[\"exp_avg\"].to(\"cpu\", non_blocking=True)\n            if \"exp_avg_sq\" in v:\n                v[\"exp_avg_sq\"] = v[\"exp_avg_sq\"].to(\"cpu\", non_blocking=True)\n        gc.collect()\n        get_torch_device().empty_cache()\n\n\n@torch.no_grad()\ndef load_megatron_optimizer(optimizers):\n    def _iter_opts(opt):\n        if isinstance(opt, ChainedOptimizer):\n            return opt.chained_optimizers\n        return [opt]\n\n    for _opt in _iter_opts(optimizers):\n        load_megatron_copy_params(_opt)\n        opt_state_dict_values = _opt.optimizer.state.values()\n        for v in opt_state_dict_values:\n            if \"exp_avg\" in v:\n                v[\"exp_avg\"] = v[\"exp_avg\"].to(get_device_id(), non_blocking=True)\n            if \"exp_avg_sq\" in v:\n                v[\"exp_avg_sq\"] = v[\"exp_avg_sq\"].to(get_device_id(), non_blocking=True)\n        gc.collect()\n        get_torch_device().empty_cache()\n\n\ndef print_rank_0(message):\n    \"\"\"If distributed is initialized, print only on rank 0.\"\"\"\n    if torch.distributed.is_initialized():\n        if torch.distributed.get_rank() == 0:\n            print(message, flush=True)\n    else:\n        print(message, flush=True)\n\n\ndef get_dist_checkpoint_path(checkpoint_path):\n    local_mkdir_safe(checkpoint_path)\n    local_mkdir_safe(os.path.join(checkpoint_path, \"dist_ckpt\"))\n    return os.path.join(checkpoint_path, \"dist_ckpt\")\n\n\ndef get_hf_model_checkpoint_path(checkpoint_path):\n    local_mkdir_safe(checkpoint_path)\n    local_mkdir_safe(os.path.join(checkpoint_path, \"huggingface\"))\n    return os.path.join(checkpoint_path, \"huggingface\")\n\n\ndef get_transformer_config_checkpoint_path(checkpoint_path):\n    os.makedirs(checkpoint_path, exist_ok=True)\n    return os.path.join(checkpoint_path, \"transformer_config.json\")\n\n\ndef get_model_checkpoint_path(checkpoint_path):\n    os.makedirs(checkpoint_path, exist_ok=True)\n    return 
os.path.join(checkpoint_path, \"model\")\n\n\ndef get_hf_config_and_tokenizer_checkpoint_path(checkpoint_path):\n    os.makedirs(checkpoint_path, exist_ok=True)\n    return os.path.join(checkpoint_path, \"hf_config_and_tokenizer\")\n\n\ndef get_optimizer_checkpoint_path(checkpoint_path, use_distributed_optimizer=True):\n    os.makedirs(os.path.join(checkpoint_path, \"optim\"), exist_ok=True)\n    if not use_distributed_optimizer:\n        return os.path.join(checkpoint_path, \"optim\", \"optim.pt\")\n    pp_rank = mpu.get_pipeline_model_parallel_rank()\n    tp_rank = mpu.get_tensor_model_parallel_rank()\n    cp_rank = mpu.get_context_parallel_rank()\n    dp_rank = mpu.get_data_parallel_rank()\n    # TODO: support ep\n    return os.path.join(checkpoint_path, \"optim\", f\"distrib_optim_pp{pp_rank}_tp{tp_rank}_cp{cp_rank}_dp{dp_rank}.pt\")\n\n\ndef get_rng_states_checkpoint_path(checkpoint_path, only_rank0_save=True):\n    # save rng states cause interrupts\n    os.makedirs(os.path.join(checkpoint_path, \"rng_states\"), exist_ok=True)\n    if only_rank0_save:\n        return os.path.join(checkpoint_path, \"rng_states\", \"rng_states.pt\")\n    dp_rank = mpu.get_data_parallel_rank()\n    pp_rank = mpu.get_pipeline_model_parallel_rank()\n    tp_rank = mpu.get_tensor_model_parallel_rank()\n    cp_rank = mpu.get_context_parallel_rank()\n    return os.path.join(checkpoint_path, \"rng_states\", f\"rng_states_pp{pp_rank}_tp{tp_rank}_cp{cp_rank}_dp{dp_rank}.pt\")\n\n\ndef convert_megatron_model_to_transformers_model(\n    name,\n    param,\n    config: PretrainedConfig,\n    tp_size: int,\n    num_query_groups: int,\n    convert_qkv_gate_up_by_trunk_concat=False,\n):\n    \"\"\"Convert megatron model to transformers model.\"\"\"\n    new_params = {}\n\n    def convert_qkv_shard(full_tensor, q_name, k_name, v_name):\n        nonlocal config\n        nonlocal tp_size\n        nonlocal num_query_groups\n\n        q_shard_list = []\n        k_shard_list = []\n        v_shard_list = []\n        hidden_size_per_head = getattr(config, \"head_dim\", config.hidden_size // config.num_attention_heads)\n\n        if config.num_key_value_heads >= tp_size:\n            q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size\n            kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size\n            total_size = q_size_tp + 2 * kv_size_tp\n            for i in range(tp_size):\n                num_query_groups_per_partition = num_query_groups // tp_size\n                qkv_part = full_tensor[i * total_size : (i + 1) * total_size]\n                q_size_chunk = q_size_tp // num_query_groups_per_partition\n                kv_size_chunk = kv_size_tp // num_query_groups_per_partition\n                for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition):\n                    q_part = qkv_part_chunk[:q_size_chunk]\n                    k_part = qkv_part_chunk[q_size_chunk : q_size_chunk + kv_size_chunk]\n                    v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk :]\n                    q_shard_list.append(q_part)\n                    k_shard_list.append(k_part)\n                    v_shard_list.append(v_part)\n        else:\n            q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size\n            kv_size_tp = hidden_size_per_head\n            total_size = q_size_tp + 2 * kv_size_tp\n            for i in range(tp_size):\n                num_query_groups_per_partition = num_query_groups // tp_size\n                qkv_part = 
full_tensor[i * total_size : (i + 1) * total_size]\n                q_size_chunk = q_size_tp // num_query_groups_per_partition\n                kv_size_chunk = kv_size_tp // num_query_groups_per_partition\n                for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition):\n                    q_part = qkv_part_chunk[:q_size_chunk]\n                    k_part = qkv_part_chunk[q_size_chunk : q_size_chunk + kv_size_chunk]\n                    v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk :]\n                    q_shard_list.append(q_part)\n                    if i * config.num_key_value_heads % tp_size == 0:\n                        k_shard_list.append(k_part)\n                        v_shard_list.append(v_part)\n\n        new_params[q_name] = torch.cat(q_shard_list, dim=0)\n        new_params[k_name] = torch.cat(k_shard_list, dim=0)\n        new_params[v_name] = torch.cat(v_shard_list, dim=0)\n\n    def convert_gate_up_shard(full_tensor, gate_name, up_name):\n        nonlocal config\n        nonlocal tp_size\n\n        intermediate_size_tp = config.intermediate_size // tp_size\n        gate_weight_list = []\n        up_weight_list = []\n        for i in range(tp_size):\n            gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)]\n            gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp]\n            up_weight_tp = gate_up_weight_tp[intermediate_size_tp:]\n            gate_weight_list.append(gate_weight_tp)\n            up_weight_list.append(up_weight_tp)\n\n        new_params[gate_name] = torch.cat(gate_weight_list, dim=0)\n        new_params[up_name] = torch.cat(up_weight_list, dim=0)\n\n    if name == \"embedding.word_embeddings.weight\":\n        new_params[\"model.embed_tokens.weight\"] = param\n    elif \"self_attention\" in name:\n        splitted_name = name.split(\".\")\n        layer_number = splitted_name[2]\n        component = splitted_name[4]\n        param_type = splitted_name[5]\n        if component == \"linear_proj\":\n            new_params[f\"model.layers.{layer_number}.self_attn.o_proj.weight\"] = param\n        elif component == \"linear_qkv\" and not isinstance(param, list):\n            if param_type == \"layer_norm_weight\":\n                new_params[f\"model.layers.{layer_number}.input_layernorm.weight\"] = param\n            else:\n                if convert_qkv_gate_up_by_trunk_concat:\n                    convert_qkv_shard(\n                        param,\n                        f\"model.layers.{layer_number}.self_attn.q_proj.{param_type}\",\n                        f\"model.layers.{layer_number}.self_attn.k_proj.{param_type}\",\n                        f\"model.layers.{layer_number}.self_attn.v_proj.{param_type}\",\n                    )\n                else:\n                    new_params[f\"model.layers.{layer_number}.self_attn.qkv_proj.{param_type}\"] = param\n        elif component == \"q_layernorm\" or component == \"k_layernorm\":\n            hf_component = component.replace(\"layer\", \"\")\n            new_params[f\"model.layers.{layer_number}.self_attn.{hf_component}.weight\"] = param\n        else:\n            assert isinstance(param, list) and len(param) == 3\n            assert param_type == \"weight\" or param_type == \"bias\"\n            new_params[f\"model.layers.{layer_number}.self_attn.q_proj.{param_type}\"] = param[0]\n            new_params[f\"model.layers.{layer_number}.self_attn.k_proj.{param_type}\"] = param[1]\n            
new_params[f\"model.layers.{layer_number}.self_attn.v_proj.{param_type}\"] = param[2]\n    elif \"mlp\" in name:\n        splitted_name = name.split(\".\")\n        layer_number = splitted_name[2]\n        component = splitted_name[4]\n        param_type = splitted_name[5]\n        if component == \"linear_fc1\" and not isinstance(param, list):\n            if param_type == \"layer_norm_weight\":\n                new_params[f\"model.layers.{layer_number}.post_attention_layernorm.weight\"] = param\n            elif param_type == \"weight\":\n                if convert_qkv_gate_up_by_trunk_concat:\n                    convert_gate_up_shard(\n                        param,\n                        f\"model.layers.{layer_number}.mlp.gate_proj.weight\",\n                        f\"model.layers.{layer_number}.mlp.up_proj.weight\",\n                    )\n                else:\n                    new_params[f\"model.layers.{layer_number}.mlp.gate_up_proj.weight\"] = param\n        elif component == \"linear_fc1\" and isinstance(param, list):\n            assert len(param) == 2\n            assert param_type == \"weight\" or param_type == \"bias\"\n            new_params[f\"model.layers.{layer_number}.mlp.gate_proj.weight\"] = param[0]\n            new_params[f\"model.layers.{layer_number}.mlp.up_proj.weight\"] = param[1]\n        elif component == \"linear_fc2\":\n            new_params[f\"model.layers.{layer_number}.mlp.down_proj.weight\"] = param\n    elif name == \"decoder.final_layernorm.weight\":\n        new_params[\"model.norm.weight\"] = param\n    elif name == \"output_layer.weight\":\n        new_params[\"lm_head.weight\"] = param\n    else:\n        raise ValueError(f\"Unknown param name: {name}\")\n    return new_params.keys(), new_params.values()\n\n\ndef broadcast_from_megatron_pp(tensor: torch.Tensor):\n    # tensor is not None only in one of the pp ranks\n    if tensor is not None:\n        shape = tensor.shape\n        dtype = tensor.dtype\n        tensor_parallel = getattr(tensor, \"tensor_model_parallel\", None)\n        partition_dim = getattr(tensor, \"partition_dim\", None)\n        tensor_spec = (shape, dtype, tensor_parallel, partition_dim)\n    else:\n        tensor_spec = None\n    tensor_spec_output = [None] * mpu.get_pipeline_model_parallel_world_size()\n    torch.distributed.all_gather_object(object_list=tensor_spec_output, obj=tensor_spec, group=mpu.get_pipeline_model_parallel_group())\n    # find the src rank\n    target_tensor_spec = None\n    src_rank = None\n    for rank, tensor_spec in enumerate(tensor_spec_output):\n        if tensor_spec is not None:\n            if target_tensor_spec is None:\n                target_tensor_spec = tensor_spec\n            else:\n                raise ValueError(\"A tensor exists on two pp ranks\")\n            src_rank = rank\n    assert target_tensor_spec is not None\n    if tensor is None:\n        tensor = torch.empty(size=target_tensor_spec[0], dtype=target_tensor_spec[1], device=get_device_id())\n        if target_tensor_spec[2] is not None:\n            tensor.tensor_model_parallel = target_tensor_spec[2]\n        if target_tensor_spec[3] is not None:\n            tensor.partition_dim = target_tensor_spec[3]\n\n    global_rank = torch.distributed.get_global_rank(group=mpu.get_pipeline_model_parallel_group(), group_rank=src_rank)\n    torch.distributed.broadcast(tensor=tensor, src=global_rank, group=mpu.get_pipeline_model_parallel_group())\n    return tensor\n\n\ndef broadcast_str_from_megatron_pp(obj: Any):\n    
obj_output = [None] * mpu.get_pipeline_model_parallel_world_size()\n    torch.distributed.all_gather_object(object_list=obj_output, obj=obj, group=mpu.get_pipeline_model_parallel_group())\n\n    src_rank = None\n    target_obj = None\n    for rank, item in enumerate(obj_output):\n        if item is not None:\n            if target_obj is not None:\n                raise ValueError(\"An object exists on two pp ranks\")\n            target_obj = item\n            src_rank = rank\n\n    assert target_obj is not None, \"No valid object found to broadcast.\"\n\n    global_rank = torch.distributed.get_global_rank(group=mpu.get_pipeline_model_parallel_group(), group_rank=src_rank)\n\n    obj_output = [None] * torch.distributed.get_world_size(group=mpu.get_pipeline_model_parallel_group())\n    obj_output[0] = target_obj\n    torch.distributed.broadcast_object_list(object_list=obj_output, src=global_rank, group=mpu.get_pipeline_model_parallel_group())\n\n    return obj_output[0]\n\n\ndef default_tp_concat_fn(layer_name_mapping, name, train_params, infer_params, model_config, convert_qkv_gate_up_by_simple_split=False):\n    \"\"\"\n    name: name of the parameter\n    train_params: training parameters\n    infer_params (Iterable[torch.Tensor]): a iterator towards list of parameters all-gathered from micro_dp_group\n    model_config: huggingface model_config\n    TODO(zhangchi.usc1992): currently, the implementation is adhoc. We can move this function to the model\n    definition so that it is model-agnostic. If the model doesn't implement this function,\n    we can throw an error to force user disable TP HybridEngine.\n    \"\"\"\n    from megatron.core import mpu\n\n    if layer_name_mapping.get(\"qkv_layer_name\") in name and \"layer_norm\" not in name:\n        # if the tensor is qkv, for each param on tp, split into q, k, v\n        # concat q, k, v separately.\n        q_lst = []\n        k_lst = []\n        v_lst = []\n        assert model_config.num_attention_heads % model_config.num_key_value_heads == 0\n        num_q_per_kv = model_config.num_attention_heads // model_config.num_key_value_heads\n        assert infer_params[0].shape[0] % (num_q_per_kv + 2) == 0, f\"param '{name}' shape '{infer_params[0].shape}' dim0 is not divisible by {num_q_per_kv + 2}\"\n        kv_size_per_tp = infer_params[0].shape[0] // (num_q_per_kv + 2)\n        split_size = [kv_size_per_tp * num_q_per_kv, kv_size_per_tp, kv_size_per_tp]\n        for infer_param in infer_params:\n            num_query_groups_per_partition = model_config.num_key_value_heads // mpu.get_tensor_model_parallel_world_size()\n            for chunk in infer_param.chunk(num_query_groups_per_partition):\n                split_size = [kv_size_per_tp * num_q_per_kv // num_query_groups_per_partition, kv_size_per_tp // num_query_groups_per_partition, kv_size_per_tp // num_query_groups_per_partition]\n                q, k, v = chunk.split(split_size)\n                q_lst.append(q)\n                k_lst.append(k)\n                v_lst.append(v)\n        q = torch.cat(q_lst, dim=0)\n        k = torch.cat(k_lst, dim=0)\n        v = torch.cat(v_lst, dim=0)\n        infer_params = torch.cat((q, k, v), dim=0) if not convert_qkv_gate_up_by_simple_split else [q, k, v]\n\n    elif layer_name_mapping.get(\"gate_proj_layer_name\") in name:\n        # if the tensor is gate and proj\n        gate_lst = []\n        up_lst = []\n        for infer_param in infer_params:\n            gate, up = infer_param.chunk(2)\n            gate_lst.append(gate)\n            
up_lst.append(up)\n        gate = torch.cat(gate_lst, dim=0)\n        up = torch.cat(up_lst, dim=0)\n        infer_params = torch.cat((gate, up), dim=0) if not convert_qkv_gate_up_by_simple_split else [gate, up]\n\n    elif \"mlp.experts.linear_fc2.weight\" in name:  # moe\n        infer_params = torch.cat(infer_params, dim=1)\n\n    else:\n        # concat tensor\n        infer_params = torch.cat(infer_params, dim=tp_utils.get_tensor_parallel_partition_dim(train_params))\n\n    return infer_params\n\n\ndef per_tensor_generator(actor_module, model_config, weight_converter, transformer_config, layer_name_mapping, convert_qkv_gate_up_by_simple_split=True):\n    from megatron.core import parallel_state as mpu\n\n    pp_rank = mpu.get_pipeline_model_parallel_rank()\n    ep_size = mpu.get_expert_model_parallel_world_size()\n    etp_size = mpu.get_expert_tensor_parallel_world_size()\n    ep_group = mpu.get_expert_model_parallel_group()\n    etp_group = mpu.get_expert_tensor_parallel_group()\n    vpp_size = len(actor_module)\n    all_gather_group = mpu.get_tensor_model_parallel_group()\n    all_gather_group_size = torch.distributed.get_world_size(group=all_gather_group)\n\n    def tensor_generator():\n        for scan_vpp_idx in range(vpp_size):\n            existing_keys = set()\n            model = unwrap_model(actor_module[scan_vpp_idx])\n            for name, param in model.named_parameters():\n                existing_keys.add(name)\n                yield name, param\n            # note\n            # there is a bug in megatron GPTModel\n            # decoder.layers[n].mlp.router.expert_bias\" in GPTModel is not registered in named_parameter, but in state_dict().\n            # for now we patch it by adding those keys to extra_keys.\n            extra_keys = [x for x in model.state_dict().keys() if \"_extra_state\" not in x and x not in existing_keys]\n            for name in extra_keys:\n                yield name, model.state_dict()[name].to(get_device_id())\n\n    # we need first make all rank get full model information\n    meta_info = []\n    for scan_vpp_idx in range(vpp_size):\n        existing_keys = set()\n        model = unwrap_model(actor_module[scan_vpp_idx])\n        for idx, (name, _) in enumerate(model.named_parameters()):\n            existing_keys.add(name)\n            meta_info.append((pp_rank, scan_vpp_idx, idx, name))\n        extra_keys = [x for x in model.state_dict().keys() if \"_extra_state\" not in x and x not in existing_keys]\n        for name in extra_keys:\n            meta_info.append((pp_rank, scan_vpp_idx, idx, name))\n\n    obj_spec_output = [None] * mpu.get_pipeline_model_parallel_world_size()\n    torch.distributed.all_gather_object(object_list=obj_spec_output, obj=meta_info, group=mpu.get_pipeline_model_parallel_group())\n    layer_list_meta = [item for sublist in obj_spec_output for item in sublist]\n\n    gen_func = tensor_generator()\n\n    # lazy load tensor for full model\n    for cur_pp_rank, scan_vpp_idx, idx, name in layer_list_meta:\n        if model_config.tie_word_embeddings and (\"output_layers\" in name):\n            import warnings\n\n            warnings.warn(\"Current model sharing word and embedding weights, skip output layer conversion\", stacklevel=2)\n            continue\n\n        if cur_pp_rank == pp_rank:\n            try:\n                cur_name, cur_tensor = next(gen_func)\n            except StopIteration:\n                cur_name, cur_tensor = None, None\n            cur_name = normalize_model_name(name, cur_pp_rank, 
scan_vpp_idx, transformer_config)\n        else:\n            cur_tensor, cur_name = None, None\n\n        # pp broadcast model tensor and name\n        cur_name = broadcast_str_from_megatron_pp(cur_name)\n        broad_pp_tensor = broadcast_from_megatron_pp(cur_tensor)\n\n        # (xya): this is a hack to fix the name of the parameters\n        while cur_name.startswith(\"module.\"):\n            cur_name = cur_name[len(\"module.\") :]\n\n        # EP\n        if \".mlp.experts.linear_fc\" in cur_name and ep_size > 1:\n            num_experts = weight_converter.mcore_config.num_moe_experts\n            num_experts_per_rank = num_experts // ep_size\n            infer_params = [torch.empty_like(broad_pp_tensor) for _ in range(ep_size)]\n            torch.distributed.all_gather(infer_params, broad_pp_tensor, group=ep_group)\n\n            name_prefix, local_expert_id = cur_name.split(\".weight\")\n            local_expert_id = int(local_expert_id)\n            global_expert_ids = [num_experts_per_rank * ep_rank + local_expert_id for ep_rank in range(ep_size)]\n            global_expert_names = [f\"{name_prefix}.weight{expert_id}\" for expert_id in global_expert_ids]\n\n            for name, param in zip(global_expert_names, infer_params):\n                if etp_size > 1:\n                    # gather etp\n                    etp_params = [torch.empty_like(param) for _ in range(etp_size)]\n                    torch.distributed.all_gather(etp_params, param, group=etp_group)\n                    params = etp_params\n                else:\n                    params = [param]\n\n                merge_params = default_tp_concat_fn(layer_name_mapping, name, broad_pp_tensor, params, model_config, convert_qkv_gate_up_by_simple_split)\n                if not isinstance(merge_params, list):\n                    merge_params = [merge_params]\n                converted_names, converted_params = weight_converter.convert_param(name, merge_params)\n\n                yield from zip(converted_names, converted_params)\n            continue\n\n        # tp all gather\n        if tp_utils.is_tensor_parallel_param(broad_pp_tensor):\n            # allocate a new tensor with proper size\n            if all_gather_group_size <= 1:\n                infer_params = [broad_pp_tensor]\n            else:\n                infer_params = [torch.empty_like(broad_pp_tensor) for _ in range(all_gather_group_size)]\n                torch.distributed.all_gather(infer_params, broad_pp_tensor, group=mpu.get_tensor_model_parallel_group())\n            infer_params = default_tp_concat_fn(layer_name_mapping, cur_name, broad_pp_tensor, infer_params, model_config, convert_qkv_gate_up_by_simple_split)\n        else:\n            infer_params = broad_pp_tensor\n\n        if not isinstance(infer_params, list):\n            infer_params = [infer_params]\n        converted_names, converted_params = weight_converter.convert_param(cur_name, infer_params)\n\n        yield from zip(converted_names, converted_params)\n\n\ndef get_transformer_layer_offset(pipeline_rank, vp_rank, config: TransformerConfig):\n    '''\n    Get the index offset of any pipeline stage, given the level of pipelining.\n\n    Make pp_rank and vpp_rank as two arguments to make it more flexible,\n    which is able to fetch layer offset for any pipeline stage.\n    The original function only returns the layer offset for current pipeline stage.\n\n    Extension to 
https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/transformer_layer.py::get_transformer_layer_offset\"\"\"\n    '''\n    if config.pipeline_model_parallel_size > 1:\n        if config.num_layers_in_first_pipeline_stage is not None or config.num_layers_in_last_pipeline_stage is not None:\n            # Calculate number of pipeline stages to distribute the remaining Transformer\n            # layers after deducting the Transformer layers in the first or the last stages\n            middle_pipeline_stages = config.pipeline_model_parallel_size\n            middle_pipeline_stages -= sum(\n                [\n                    1 if x is not None else 0\n                    for x in (\n                        config.num_layers_in_first_pipeline_stage,\n                        config.num_layers_in_last_pipeline_stage,\n                    )\n                ]\n            )\n\n            # Calculate layers to distribute in each pipeline stage. If the\n            # num_layers_in_first_pipeline_stage and num_layers_in_last_pipeline_stage\n            # are not set, we will not enable uneven pipeline. All layers will be treated\n            # as middle layers.\n            num_layers_in_first_pipeline_stage = 0 if config.num_layers_in_first_pipeline_stage is None else config.num_layers_in_first_pipeline_stage\n            num_layers_in_last_pipeline_stage = 0 if config.num_layers_in_last_pipeline_stage is None else config.num_layers_in_last_pipeline_stage\n\n            middle_num_layers = config.num_layers - num_layers_in_first_pipeline_stage - num_layers_in_last_pipeline_stage\n\n            if mpu.get_virtual_pipeline_model_parallel_world_size() is not None:\n                vp_size = mpu.get_virtual_pipeline_model_parallel_world_size()\n\n                # Calculate number of layers in each virtual model chunk\n                # If the num_layers_in_first_pipeline_stage and\n                # num_layers_in_last_pipeline_stage are not set, all pipeline stages\n                # will be treated as middle pipeline stages in the calculation\n                num_layers_per_virtual_model_chunk_in_first_pipeline_stage = 0 if config.num_layers_in_first_pipeline_stage is None else config.num_layers_in_first_pipeline_stage // vp_size\n\n                num_layers_per_virtual_model_chunk_in_last_pipeline_stage = 0 if config.num_layers_in_last_pipeline_stage is None else config.num_layers_in_last_pipeline_stage // vp_size\n\n                num_layers_per_vritual_model_chunk_in_middle_pipeline_stage = middle_num_layers // vp_size\n\n                # First stage + middle stage + last stage\n                total_virtual_chunks = num_layers_per_virtual_model_chunk_in_first_pipeline_stage + num_layers_per_vritual_model_chunk_in_middle_pipeline_stage + num_layers_per_virtual_model_chunk_in_last_pipeline_stage\n\n                # Calculate the layer offset with interleaved uneven pipeline parallelism\n                if pipeline_rank == 0:\n                    offset = vp_rank * total_virtual_chunks\n                else:\n                    offset = vp_rank * total_virtual_chunks + num_layers_per_virtual_model_chunk_in_first_pipeline_stage + (pipeline_rank - 1) * (num_layers_per_vritual_model_chunk_in_middle_pipeline_stage // middle_pipeline_stages)\n            else:\n                if middle_pipeline_stages > 0:\n                    num_layers_per_pipeline_rank = middle_num_layers // middle_pipeline_stages\n                else:\n                    num_layers_per_pipeline_rank = 
0\n\n                middle_pipeline_rank = pipeline_rank if config.num_layers_in_first_pipeline_stage is None else pipeline_rank - 1\n\n                if pipeline_rank == 0:\n                    offset = 0\n                else:\n                    offset = (middle_pipeline_rank * num_layers_per_pipeline_rank) + num_layers_in_first_pipeline_stage\n        else:\n            num_layers = config.num_layers\n\n            # Increase the number of layers by one if we include the embedding (loss)\n            # layer into pipeline parallelism partition and placement\n            if config.account_for_embedding_in_pipeline_split:\n                num_layers += 1\n\n            if config.account_for_loss_in_pipeline_split:\n                num_layers += 1\n\n            num_layers_per_pipeline_rank = num_layers // config.pipeline_model_parallel_size\n\n            if mpu.get_virtual_pipeline_model_parallel_world_size() is not None:\n                vp_size = mpu.get_virtual_pipeline_model_parallel_world_size()\n\n                num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size\n                total_virtual_chunks = num_layers // vp_size\n                offset = vp_rank * total_virtual_chunks + (pipeline_rank * num_layers_per_virtual_rank)\n\n                # Reduce the offset of embedding layer from the total layer number\n                if config.account_for_embedding_in_pipeline_split and not mpu.is_pipeline_first_stage():\n                    offset -= 1\n            else:\n                offset = pipeline_rank * num_layers_per_pipeline_rank\n\n                # Reduce the offset of embedding layer from the total layer number\n                if config.account_for_embedding_in_pipeline_split and not mpu.is_pipeline_first_stage():\n                    offset -= 1\n    else:\n        offset = 0\n    return offset\n"
  },
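The gate/up branch of `default_tp_concat_fn` assumes each tensor-parallel shard of `linear_fc1` stacks its gate slice on top of its up slice along dim 0, so merging means splitting every shard in half and concatenating all gate halves, then all up halves. A CPU-only sketch of that merge with toy sizes (intermediate_size=8, hidden_size=4, tp_size=2 are assumptions for illustration):

```python
import torch

# Toy dimensions chosen only for illustration.
intermediate_size, hidden_size, tp_size = 8, 4, 2
full_gate = torch.arange(intermediate_size * hidden_size, dtype=torch.float32).reshape(intermediate_size, hidden_size)
full_up = full_gate + 100.0

# Megatron stores linear_fc1 per TP rank as [gate_shard; up_shard] stacked on dim 0.
per_rank = intermediate_size // tp_size
shards = [
    torch.cat([full_gate[r * per_rank:(r + 1) * per_rank], full_up[r * per_rank:(r + 1) * per_rank]], dim=0)
    for r in range(tp_size)
]

# The merge path: split each shard back into (gate, up), then concat all gates
# and all ups across TP ranks, as the gate_proj branch above does.
gate_lst, up_lst = [], []
for shard in shards:
    gate, up = shard.chunk(2)
    gate_lst.append(gate)
    up_lst.append(up)

assert torch.equal(torch.cat(gate_lst, dim=0), full_gate)
assert torch.equal(torch.cat(up_lst, dim=0), full_up)
```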
  {
    "path": "siirl/utils/megatron/memory.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\n\nfrom siirl.utils.extras.device import get_device_id\n\nclass MemoryBuffer:\n    def __init__(self, numel, numel_padded, dtype):\n        self.numel = numel\n        self.numel_padded = numel_padded\n        self.dtype = dtype\n        self.data = torch.zeros(self.numel_padded, dtype=self.dtype, device=get_device_id(), requires_grad=False)\n\n    def zero(self):\n        \"\"\"Reset the buffer to zero.\"\"\"\n        self.data.zero_()\n\n    def get(self, shape, start_index):\n        \"\"\"Return a tensor with the input `shape` as a view into the\n        1-D data starting at `start_index`.\"\"\"\n        end_index = start_index + shape.numel()\n        assert end_index <= self.numel, \"requested tensor is out of the buffer range.\"\n        buffer_tensor = self.data[start_index:end_index]\n        buffer_tensor = buffer_tensor.view(shape)\n        return buffer_tensor\n"
  },
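`MemoryBuffer.get()` hands out reshaped views into one flat 1-D tensor, so tensors obtained from the buffer share its storage and writes through them mutate the buffer in place. A minimal CPU-only sketch of those view semantics (the real class allocates on the accelerator via `get_device_id()`):

```python
import torch

# Flat backing storage, as MemoryBuffer allocates in __init__ (CPU here for the sketch).
numel_padded = 16
data = torch.zeros(numel_padded)

# Equivalent of MemoryBuffer.get(shape, start_index): slice the 1-D buffer and view it.
shape = torch.Size([2, 3])
start_index = 4
end_index = start_index + shape.numel()
assert end_index <= numel_padded, "requested tensor is out of the buffer range."
view = data[start_index:end_index].view(shape)

# Writing through the view mutates the shared flat buffer: no copy is made.
view.fill_(1.0)
assert data[start_index:end_index].sum() == shape.numel()
```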
  {
    "path": "siirl/utils/megatron/memory_buffer.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThis file contains utilities to manipulate torch memory buffers\n\"\"\"\n\nfrom typing import Dict, List, Optional\n\nimport torch\nfrom torch import nn\n\nfrom siirl.utils.extras.device import get_device_name\n\n\nclass MemoryBuffer:\n    \"\"\"\n    A memory buffer is a contiguous torch tensor that may combine multiple tensors sharing with the underlying\n    memory. It must have a unique type to support this behavior.\n    \"\"\"\n\n    def __init__(self, numel: int, numel_padded: int, dtype: torch.dtype, source: Optional[torch.Tensor] = None):\n        self.numel = numel\n        self.numel_padded = numel_padded\n        self.dtype = dtype\n        if source is not None:\n            self.data = source\n        else:\n            self.data = torch.zeros(self.numel_padded, dtype=self.dtype, device=get_device_name(), requires_grad=False)\n\n    def zero(self):\n        \"\"\"Reset the buffer to zero.\"\"\"\n        self.data.zero_()\n\n    def get(self, shape, start_index):\n        \"\"\"Return a tensor with the input `shape` as a view into the\n        1-D data starting at `start_index`.\"\"\"\n        end_index = start_index + shape.numel()\n        assert end_index <= self.numel, \"requested tensor is out of the buffer range.\"\n        buffer_tensor = self.data[start_index:end_index]\n        buffer_tensor = buffer_tensor.view(shape)\n        return buffer_tensor\n\n\ndef calc_padded_numel(shape: torch.Size, dtype: torch.dtype):\n    \"\"\"for cuda memory alignment, make sure alignment by 128-bits\"\"\"\n    align_numel = 128 // torch.finfo(dtype).bits\n    numel = shape.numel()\n    return (numel + align_numel - 1) // align_numel * align_numel\n\n\ndef get_weight_buffer_meta_from_module(module: nn.Module) -> Dict[str, Dict]:\n    \"\"\"\n    Return a dictionary containing name to a shape and dtype.\n    \"\"\"\n    weight_buffer_meta = {}\n    for name, param in sorted(module.named_parameters()):\n        weight_buffer_meta[name] = {\"shape\": param.shape, \"dtype\": param.dtype}\n    return weight_buffer_meta\n\n\ndef build_memory_buffer(weight_buffer_meta: Dict[str, Dict]) -> Dict[torch.dtype, MemoryBuffer]:\n    \"\"\"Build the memory buffer given weight_buffer_meta\n\n    Args:\n        weight_buffer_meta: contains mapping from name to a dictionary containing shape and dtype of the tensors\n\n    Returns: a large memory buffer for each dtype that can hold all the tensors\n\n    \"\"\"\n    memory_buffers = {}\n    total_numel_map = {}  # map from dtype to the total numel\n    for name, meta_info in sorted(weight_buffer_meta.items()):\n        shape = meta_info[\"shape\"]\n        dtype = meta_info[\"dtype\"]\n\n        assert isinstance(shape, torch.Size)\n        assert isinstance(dtype, torch.dtype)\n\n        if dtype not in total_numel_map:\n            total_numel_map[dtype] = 0\n\n        total_numel_map[dtype] += calc_padded_numel(shape, 
dtype)\n\n    for dtype, total_numel in total_numel_map.items():\n        memory_buffers[dtype] = MemoryBuffer(total_numel, total_numel, dtype)\n\n    return memory_buffers\n\n\ndef build_memory_reference_from_module(module: torch.nn.Module, memory_buffers: Dict[torch.dtype, MemoryBuffer], maintain_weight=True):\n    start_index = {}\n    for dtype in memory_buffers:\n        start_index[dtype] = 0\n    for name, param in sorted(module.named_parameters()):\n        memory_buffer = memory_buffers[param.dtype]\n        buffer = memory_buffer.get(shape=param.shape, start_index=start_index[param.dtype])\n        # need to increment start_index\n        start_index[param.dtype] += calc_padded_numel(param.shape, param.dtype)\n        if maintain_weight:\n            buffer.copy_(param.data)\n        param.data = buffer\n\n\ndef build_memory_reference(weight_buffer_meta: Dict[str, Dict], memory_buffers: Dict[torch.dtype, MemoryBuffer]):\n    \"\"\"Build the memory references. The memory buffers are built using the build_memory_buffer API.\n    This API will allocate a weight buffer pointer to the memory buffer according to the weight_buffer_meta.\n\n    Args:\n        weight_buffer_meta:\n        memory_buffers:\n\n    Returns:\n\n    \"\"\"\n    start_idx = {}\n    weight_buffers = {}\n    for dtype in memory_buffers:\n        start_idx[dtype] = 0\n\n    for name, meta_info in sorted(weight_buffer_meta.items()):\n        shape = meta_info[\"shape\"]\n        dtype = meta_info[\"dtype\"]\n\n        buffer = memory_buffers[dtype].get(shape, start_index=start_idx[dtype])\n        start_idx[dtype] += calc_padded_numel(shape, dtype)\n        weight_buffers[name] = buffer\n\n    return weight_buffers\n\n\nclass MemoryBufferModuleWrapper:\n    \"\"\"\n    Note that we do not design MemoryBufferModuleWrapper as an nn.Module due to\n    - It will change the checkpoint name\n    \"\"\"\n\n    def __init__(self, module: nn.Module):\n        super().__init__()\n        self.module = module\n        self.weight_buffer_meta = get_weight_buffer_meta_from_module(self.module)\n        self.memory_buffers = build_memory_buffer(self.weight_buffer_meta)\n        build_memory_reference_from_module(self.module, self.memory_buffers)\n\n    def get_memory_buffers(self):\n        return self.memory_buffers\n\n    def get_weight_buffer_meta(self):\n        return self.weight_buffer_meta\n\n\nclass MegatronMemoryBufferForRollout:\n    \"\"\"\n    We assume that\n    - inference engine has tp + dp\n    - actor has tp + pp + dp\n    - the tp between inference engine and actor should be the same\n    - memory_buffers: contains a list of memory_buffers, each is a dict from dtype to MemoryBuffer\n    - weight_buffers: contains a list of weight_buffers, each is a dict from name to param\n    - named_parameters: a dict from name to parameter that normalizes the names from pp and vpp. Note that\n        the named_parameters may not be directly compatible with inference engine. User has to take care of\n        this part such as the layout mismatches. (e.g. 
qkv transpose)\n    - Note that weight_buffer, named_parameters and memory_buffers share the same underlying GPU memory.\n    - When doing weight sync, the data is transferred via memory buffers\n    \"\"\"\n\n    def __init__(self, transform_memory_param_fn):\n        self._memory_buffers = []\n        self._weight_buffers = []\n        self._named_parameters = {}\n        self.transform_memory_param_fn = transform_memory_param_fn\n\n    def initialize_weight_buffer(self, weight_buffer_meta_pp: List[Dict[str, Dict]]):\n        \"\"\"\n        Initialize the weight buffers. The weight buffer layout is obtained from the actor. We construct\n        one large buffer per dtype in each weight buffer.\n\n        Args:\n            weight_buffer_meta_pp: a list with one entry per pp stage; each entry maps a parameter name to a\n                dictionary containing its shape and dtype\n\n        Returns: None\n\n        \"\"\"\n        self.weight_buffer_meta_pp = weight_buffer_meta_pp\n\n        for weight_buffer_meta in self.weight_buffer_meta_pp:\n            memory_buffer = build_memory_buffer(weight_buffer_meta)\n            self._memory_buffers.append(memory_buffer)\n            self._weight_buffers.append(None)\n\n    def build_memory_reference(self):\n        for i, weight_buffer_meta in enumerate(self.weight_buffer_meta_pp):\n            self._weight_buffers[i] = build_memory_reference(weight_buffer_meta, self._memory_buffers[i])\n        self._named_parameters = self.transform_memory_param_fn(self._weight_buffers)\n\n    @property\n    def named_parameters(self):\n        return self._named_parameters\n\n    @property\n    def weight_buffers(self):\n        return self._weight_buffers\n\n    @property\n    def memory_buffers(self):\n        return self._memory_buffers\n"
  },
  {
    "path": "siirl/utils/megatron/optimizer.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom megatron.core.optimizer import OptimizerConfig\nfrom megatron.core.optimizer import get_megatron_optimizer as get_megatron_optimizer_native\nfrom megatron.core.optimizer_param_scheduler import OptimizerParamScheduler\n\n\ndef get_megatron_optimizer(\n    model,\n    config: OptimizerConfig,\n    no_weight_decay_cond=None,\n    scale_lr_cond=None,\n    lr_mult=1.0,\n):\n    # Base optimizer.\n    return get_megatron_optimizer_native(\n        config=config,\n        model_chunks=model,\n        no_weight_decay_cond=no_weight_decay_cond,\n        scale_lr_cond=scale_lr_cond,\n        lr_mult=lr_mult,\n    )\n\n\ndef get_megatron_optimizer_param_scheduler(\n    optimizer,\n    config,\n):\n    \"\"\"\n    Get the optimizer parameter scheduler for Megatron.\n    \"\"\"\n    lr_decay_steps = config.lr_decay_steps\n    lr_warmup_steps = config.lr_warmup_steps\n    if config.lr_decay_steps is None:\n        lr_decay_steps = config.total_training_steps\n    wsd_decay_steps = None\n    if config.lr_wsd_decay_steps is not None:\n        wsd_decay_steps = config.lr_wsd_decay_steps\n    if config.lr_warmup_steps_ratio is not None and (\n        config.lr_warmup_steps is None or config.lr_warmup_steps <= 0\n    ):\n        lr_warmup_steps = int(config.lr_warmup_steps_ratio * lr_decay_steps)\n\n    opt_param_scheduler = OptimizerParamScheduler(\n        optimizer,\n        init_lr=config.lr_warmup_init,\n        max_lr=config.lr,\n        min_lr=config.min_lr,\n        lr_warmup_steps=lr_warmup_steps,\n        lr_decay_steps=lr_decay_steps,\n        lr_decay_style=config.lr_decay_style,\n        start_wd=config.weight_decay,\n        end_wd=config.weight_decay,\n        wd_incr_steps=config.total_training_steps,\n        wd_incr_style=config.weight_decay_incr_style,\n        use_checkpoint_opt_param_scheduler=config.use_checkpoint_opt_param_scheduler,\n        override_opt_param_scheduler=(not config.use_checkpoint_opt_param_scheduler),\n        wsd_decay_steps=wsd_decay_steps,\n        lr_wsd_decay_style=config.lr_wsd_decay_style,\n    )\n\n    return opt_param_scheduler\n\n\ndef get_megatron_last_lr(optimizer):\n    \"\"\"\n    Get the last learning rate from the optimizer parameter scheduler.\n    \"\"\"\n    return optimizer.param_groups[0][\"lr\"]\n"
  },
  {
    "path": "siirl/utils/megatron/pipeline_parallel.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nfrom megatron.core import parallel_state as mpu\n\nfrom siirl.utils.megatron.sequence_parallel import pad_to_sequence_parallel\n\n\ndef compute_transformers_input_shapes(batches, meta_info):\n    from flash_attn.bert_padding import unpad_input  # flash 2 is a must for Megatron\n\n    # pre-compute input shapes for each micro-batch at each pp stage\n    input_shapes = []\n    for model_inputs in batches:\n        input_ids = model_inputs[\"input_ids\"]\n        attention_mask = model_inputs[\"attention_mask\"]\n        input_ids_rmpad = unpad_input(input_ids.unsqueeze(dim=-1), attention_mask)[0]  # (total_nnz, 1)\n        if meta_info[\"sequence_parallel\"]:\n            input_ids_rmpad = pad_to_sequence_parallel(input_ids_rmpad)\n            # compute shapes for model_inputs\n            input_shapes.append(\n                torch.Size(\n                    [\n                        input_ids_rmpad.shape[0] // mpu.get_tensor_model_parallel_world_size(),\n                        1,\n                        meta_info[\"hidden_size\"],\n                    ]\n                )\n            )\n        else:\n            # compute shapes for model_inputs\n            input_shapes.append(torch.Size([input_ids_rmpad.shape[0], 1, meta_info[\"hidden_size\"]]))\n    return input_shapes\n\n\ndef make_batch_generator(batches, vpp_size):\n    \"\"\"\n    Creates a batch generator suitable for Megatron pipeline parallelism,\n    handling virtual pipeline parallelism (VPP).\n\n    If VPP is used (vpp_size > 1), it duplicates the batch iterator for each\n    virtual pipeline stage. Otherwise, it returns a single iterator.\n\n    Args:\n        batches: An iterable (e.g., list) of micro-batches.\n        vpp_size (int): The virtual pipeline model parallel size.\n\n    Returns:\n        An iterator or a list of iterators over the micro-batches.\n    \"\"\"\n    if vpp_size > 1:\n        # has vpp\n        batch_generator = [batches] * vpp_size  # number of vpp chunks\n        batch_generator = [iter(b) for b in batch_generator]\n    else:\n        # no vpp\n        batch_generator = iter(batches)\n    return batch_generator\n"
  },
  {
    "path": "siirl/utils/megatron/sequence_parallel.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nimport torch.nn.functional as F\nfrom megatron.core import parallel_state as mpu\n\n\ndef mark_parameter_as_sequence_parallel(parameter):\n    parameter.sequence_parallel = True\n\n\ndef is_sequence_parallel_param(param):\n    return hasattr(param, \"sequence_parallel\") and param.sequence_parallel\n\n\ndef pad_to_sequence_parallel(unpad_tokens: torch.Tensor):\n    \"\"\"pad the tokens such that the total length is a multiple of sp world size\n\n    Args:\n        unpad_tokens: (total_nnz, ...). Tokens after removing padding\n\n    Returns:\n        the padded tokens: (total_nnz + pad_size,...)\n\n    \"\"\"\n    total_nnz = unpad_tokens.shape[0]\n    sp_world_size = mpu.get_tensor_model_parallel_world_size()\n\n    pad_size = 0 if total_nnz % sp_world_size == 0 else sp_world_size - total_nnz % sp_world_size\n\n    if pad_size > 0:\n        if unpad_tokens.ndim == 1:\n            unpad_tokens = F.pad(unpad_tokens, (0, pad_size))\n        elif unpad_tokens.ndim == 2:\n            unpad_tokens = F.pad(unpad_tokens, (0, 0, 0, pad_size))\n        else:\n            raise NotImplementedError(f\"Padding dim {unpad_tokens.ndim()} is not supported\")\n\n    return unpad_tokens\n"
  },
  {
    "path": "siirl/utils/megatron/tensor_parallel.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nUtilities for using tensor_parallel in megatron\n\"\"\"\n\nfrom typing import TYPE_CHECKING, Dict\n\nimport torch\nimport torch.distributed as dist\nfrom megatron.core import parallel_state as mpu\nfrom torch.nn import init\n\nif TYPE_CHECKING:\n    from megatron.core import ModelParallelConfig\n\n\ndef update_kwargs_with_config(dictionary: Dict, config: \"ModelParallelConfig\"):\n    dictionary[\"config\"] = config\n    return dictionary\n\n\ndef get_default_kwargs_for_model_parallel_config():\n    model_parallel_config_kwargs = {\n        \"params_dtype\": torch.float32,\n        \"use_cpu_initialization\": False,\n        \"perform_initialization\": True,\n        \"gradient_accumulation_fusion\": False,\n        \"sequence_parallel\": False,\n    }\n    return model_parallel_config_kwargs\n\n\ndef get_default_model_parallel_config():\n    from megatron.core import ModelParallelConfig\n\n    return ModelParallelConfig(**get_default_kwargs_for_model_parallel_config())\n\n\ndef get_common_default_kwargs_for_parallel_linear():\n    default_model_parallel_config = get_default_model_parallel_config()\n    common_default_kwargs = {\n        \"init_method\": init.xavier_normal_,\n        \"stride\": 1,\n        \"keep_master_weight_for_test\": False,\n        \"config\": default_model_parallel_config,\n    }\n    return common_default_kwargs\n\n\ndef get_default_kwargs_for_column_parallel_linear():\n    from megatron.core import ModelParallelConfig\n\n    model_parallel_config_kwargs = get_default_kwargs_for_model_parallel_config()\n    column_parallel_config_kwargs = {\n        \"async_tensor_model_parallel_allreduce\": False,\n    }\n    model_parallel_config_kwargs.update(column_parallel_config_kwargs)\n    column_default_kwargs = {\n        \"config\": ModelParallelConfig(**model_parallel_config_kwargs),\n    }\n    common_default_kwargs = get_common_default_kwargs_for_parallel_linear()\n    common_default_kwargs.update(column_default_kwargs)\n    return common_default_kwargs\n\n\ndef get_default_kwargs_for_row_parallel_linear():\n    common_default_kwargs = get_common_default_kwargs_for_parallel_linear()\n    return common_default_kwargs\n\n\ndef get_default_kwargs_for_parallel_embedding():\n    from megatron.core import ModelParallelConfig\n\n    model_parallel_config_kwargs = get_default_kwargs_for_model_parallel_config()\n    embedding_default_kwargs = {\n        \"init_method\": init.xavier_normal_,\n        \"config\": ModelParallelConfig(**model_parallel_config_kwargs),\n    }\n    return embedding_default_kwargs\n\n\ndef is_tensor_parallel_param(param):\n    return hasattr(param, \"tensor_model_parallel\") and param.tensor_model_parallel\n\n\ndef get_tensor_parallel_partition_dim(param):\n    assert is_tensor_parallel_param(param)\n    return param.partition_dim\n\n\ndef 
get_tensor_parallel_partition_stride(param):\n    assert is_tensor_parallel_param(param)\n    return param.partition_stride\n\n\nclass _VocabParallelEntropy(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, vocab_parallel_logits: torch.Tensor) -> torch.Tensor:\n        @torch.compile(dynamic=True)\n        def mul_reduce(a, b):\n            return (a * b).sum(dim=-1, keepdim=True)\n\n        logits_max = vocab_parallel_logits.max(dim=-1, keepdim=True).values\n        dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=mpu.get_tensor_model_parallel_group())\n        normalized_vocab_parallel_logits = vocab_parallel_logits - logits_max\n        normalized_exp_logits = normalized_vocab_parallel_logits.exp_()\n        normalized_sum_exp_logits = normalized_exp_logits.sum(dim=-1, keepdim=True)\n        dist.all_reduce(normalized_sum_exp_logits, group=mpu.get_tensor_model_parallel_group())\n        softmax_logits = normalized_exp_logits.div_(normalized_sum_exp_logits)\n        sum_softmax_times_logits = mul_reduce(softmax_logits, vocab_parallel_logits)\n        dist.all_reduce(sum_softmax_times_logits, group=mpu.get_tensor_model_parallel_group())\n        entropy = logits_max + normalized_sum_exp_logits.log() - sum_softmax_times_logits\n        ctx.save_for_backward(vocab_parallel_logits, softmax_logits, sum_softmax_times_logits)\n        return entropy.squeeze(dim=-1)\n\n    @staticmethod\n    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:\n        vocab_parallel_logits, softmax_logits, sum_softmax_times_logits = ctx.saved_tensors\n        # reuse softmax_logits as grad\n        vocab_parallel_logits.sub_(sum_softmax_times_logits)\n        softmax_logits.mul_(vocab_parallel_logits)\n        softmax_logits.mul_(grad_output.unsqueeze(dim=-1))\n        # recover vocab_parallel_logits\n        vocab_parallel_logits.add_(sum_softmax_times_logits)\n        softmax_logits.mul_(-1)\n        return softmax_logits\n\n\ndef vocab_parallel_entropy(vocab_parallel_logits: torch.Tensor) -> torch.Tensor:\n    \"\"\"Compute entropy when the logits are sharded across tp ranks\n\n    Args:\n        vocab_parallel_logits: (total_nnz, vocab_size // tp_size)\n\n    Returns: (total_nnz,)\n\n    \"\"\"\n    return _VocabParallelEntropy.apply(vocab_parallel_logits)\n\n\ndef vocab_parallel_log_probs_from_logits(logits, labels):\n    \"\"\"TODO(zhangchi.usc1992): We may change the implementation later\"\"\"\n    from megatron.core import tensor_parallel\n\n    return -tensor_parallel.vocab_parallel_cross_entropy(vocab_parallel_logits=logits, target=labels)\n\n\ndef vocab_parallel_log_probs_from_logits_response_rmpad(input_ids, attention_mask, logits_rmpad, response_length):\n    \"\"\"Similar to log_probs_from_logits_response_rmpad, but the logits_rmpad is now split across the tensor parallel region.\n    This will further reduce the peak memory usage during training\n\n    Args:\n        input_ids: [batch_size, seqlen]\n        attention_mask: [batch_size, seqlen]\n        logits_rmpad: [total_nnz, vocab_size // tp_size]\n        response_length: int\n\n    \"\"\"\n    from flash_attn.bert_padding import pad_input, unpad_input\n\n    batch_size, seqlen = input_ids.shape\n    input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask=attention_mask)\n    input_ids_rmpad = input_ids_rmpad.squeeze(-1)\n    input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=0)\n    full_log_probs_rmpad = vocab_parallel_log_probs_from_logits(logits=logits_rmpad, 
labels=input_ids_rmpad_rolled)  # (total_nnz,)\n    full_output = pad_input(hidden_states=full_log_probs_rmpad.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen)\n    output = full_output.squeeze(-1)[:, -response_length - 1 : -1]  # [batch_size, response_length]\n    return output\n"
  },
  {
    "path": "siirl/utils/memory_utils.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport gc\nimport logging\n\nfrom siirl.utils.extras.device import get_torch_device\n\nlogger = logging.getLogger(__name__)\n\n\ndef aggressive_empty_cache(force_sync: bool = True, max_retries: int = 3) -> None:\n    \"\"\"\n    More aggressive GPU memory cleanup function, tries to release PyTorch reserved but unallocated memory.\n\n    Args:\n        force_sync: Whether to force device synchronization\n        max_retries: Maximum number of retries\n    \"\"\"\n    device = get_torch_device()\n    if not device.is_available():\n        return\n\n    for attempt in range(max_retries):\n        # Record memory status before cleanup\n        before_reserved = device.memory_reserved()\n        before_allocated = device.memory_allocated()\n\n        # Run garbage collection\n        gc.collect()\n\n        # Clear PyTorch cache\n        device.empty_cache()\n\n        # Force synchronization (optional)\n        if force_sync:\n            device.synchronize()\n\n        # Record memory status after cleanup\n        after_reserved = device.memory_reserved()\n        after_allocated = device.memory_allocated()\n\n        # Calculate freed memory\n        reserved_freed = before_reserved - after_reserved\n        allocated_freed = before_allocated - after_allocated\n\n        logger.info(\n            f\"Memory cleanup attempt {attempt + 1}: Freed {reserved_freed / 1024**3:.2f} GB reserved, \"\n            f\"{allocated_freed / 1024**3:.2f} GB allocated\"\n        )\n\n        # Stop retrying if little memory was freed\n        if reserved_freed < 1024**3:  # less than 1GB\n            break\n\n\ndef reset_memory_stats() -> None:\n    \"\"\"Reset GPU memory statistics\"\"\"\n    if get_torch_device().is_available():\n        device = get_torch_device()\n        device.reset_peak_memory_stats()\n        device.reset_accumulated_memory_stats()\n\n\ndef get_memory_info() -> dict:\n    \"\"\"Get detailed GPU memory information\"\"\"\n    if not get_torch_device().is_available():\n        return {}\n\n    device = get_torch_device()\n    device_id = device.current_device()\n\n    return {\n        \"total_memory_gb\": device.get_device_properties(device_id).total_memory / 1024**3,\n        \"reserved_memory_gb\": device.memory_reserved() / 1024**3,\n        \"allocated_memory_gb\": device.memory_allocated() / 1024**3,\n        \"cached_memory_gb\": (device.memory_reserved() - device.memory_allocated()) / 1024**3,\n        \"max_memory_allocated_gb\": device.max_memory_allocated() / 1024**3,\n        \"max_memory_reserved_gb\": device.max_memory_reserved() / 1024**3,\n    }\n\n\ndef log_memory_usage(stage: str = \"current\") -> None:\n    \"\"\"Log GPU memory usage\"\"\"\n    if not get_torch_device().is_available():\n        return\n\n    info = get_memory_info()\n    logger.info(\n        f\"Memory usage [{stage}]: 
\"\n        f\"Total: {info['total_memory_gb']:.2f} GB, \"\n        f\"Allocated: {info['allocated_memory_gb']:.2f} GB, \"\n        f\"Reserved: {info['reserved_memory_gb']:.2f} GB, \"\n        f\"Cached: {info['cached_memory_gb']:.2f} GB\"\n    )\n\n\ndef optimize_memory_for_inference() -> None:\n    \"\"\"Optimize GPU memory usage for inference\"\"\"\n    if not get_torch_device().is_available():\n        return\n\n    # Set a more aggressive memory allocation policy\n    get_torch_device().set_per_process_memory_fraction(0.95)  # Use 95% of GPU memory\n\n    # Clear cache\n    aggressive_empty_cache(force_sync=True)\n\n    logger.info(\"Optimized GPU memory usage for inference\")\n\n\ndef optimize_memory_for_training() -> None:\n    \"\"\"Optimize GPU memory usage for training\"\"\"\n    if not get_torch_device().is_available():\n        return\n\n    # Set a moderate memory allocation policy\n    get_torch_device().set_per_process_memory_fraction(0.9)  # Use 90% of GPU memory\n\n    # Clear cache\n    aggressive_empty_cache(force_sync=False)\n\n    logger.info(\"Optimized GPU memory usage for training\")\n"
  },
  {
    "path": "siirl/utils/metrics/__init__.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "siirl/utils/metrics/metric_utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nMetrics related to the PPO trainer.\n\"\"\"\n\nfrom collections import defaultdict\nfrom typing import Any, Dict, List, Callable\n\nimport psutil\nimport numpy as np\nimport torch\nimport pandas as pd\nimport ray\nfrom scipy.stats import mode\nimport logging\nimport os\nfrom pathlib import Path\nfrom datetime import datetime\nfrom tensordict import TensorDict\nfrom functools import partial\nimport json\n\n\ndef _compute_response_info(batch: TensorDict) -> Dict[str, Any]:\n    \"\"\"\n    Computes information about prompts and responses from a batch.\n\n    This is an internal helper function that extracts masks and lengths for prompts and responses.\n\n    Args:\n        batch: A TensorDict object containing batch data with responses and attention masks.\n\n    Returns:\n        A dictionary containing:\n            - response_mask: Attention mask for the response tokens\n            - prompt_length: Tensor of prompt lengths for each item in the batch\n            - response_length: Tensor of response lengths for each item in the batch\n    \"\"\"\n    response_length = batch[\"responses\"].shape[-1]\n\n    prompt_mask = batch[\"attention_mask\"][:, :-response_length]\n\n    if \"response_mask\" not in batch:\n        response_mask = batch[\"attention_mask\"][:, -response_length:]\n    else:\n        response_mask = batch[\"response_mask\"]\n\n    prompt_length = prompt_mask.sum(-1).float()\n    response_length = response_mask.sum(-1).float()  # (batch_size,)\n\n    return dict(\n        response_mask=response_mask,\n        prompt_length=prompt_length,\n        response_length=response_length,\n    )\n\ndef compute_data_metric(data: TensorDict):\n    \"\"\"  \n        Computes various metrics from a batch of data for PPO training.\n        This function calculates metrics related to scores, rewards, advantages, returns, values,\n        and sequence lengths from a batch of data. It provides statistical information (mean, max, min)\n        for each metric category.\n        Args:\n            batch: A TensorDict object containing batch data with token-level scores, rewards, advantages, etc.\n            use_critic: Whether to include critic-specific metrics. 
Defaults to True.\n\n        Returns:\n            A dictionary of metrics including:\n                - critic/score/mean, max, min: Statistics about sequence scores\n                - critic/rewards/mean, max, min: Statistics about sequence rewards\n                - critic/advantages/mean, max, min: Statistics about advantages\n                - critic/returns/mean, max, min: Statistics about returns\n                - critic/values/mean, max, min: Statistics about critic values (if use_critic=True)\n                - critic/vf_explained_var: Explained variance of the value function (if use_critic=True)\n                - response_length/mean, max, min, clip_ratio: Statistics about response lengths\n                - prompt_length/mean, max, min, clip_ratio: Statistics about prompt lengths\n                - num_turns/mean, max, min: Statistics about the number of multi-turn conversations\n    \"\"\"\n    sequence_score = data[\"token_level_scores\"].sum(-1)\n    sequence_reward = data[\"token_level_rewards\"].sum(-1)\n\n    advantages = data[\"advantages\"]\n    returns = data[\"returns\"]\n\n    max_response_length = data[\"responses\"].shape[-1]\n    prompt_mask = data[\"attention_mask\"][:, :-max_response_length].bool()\n    response_mask = data[\"response_mask\"].bool()\n\n    max_prompt_length = prompt_mask.size(-1)\n\n    response_info = _compute_response_info(data)\n    prompt_length = response_info[\"prompt_length\"]\n    response_length = response_info[\"response_length\"]\n\n    valid_adv = torch.masked_select(advantages, response_mask)\n    valid_returns = torch.masked_select(returns, response_mask)\n    \n    if \"values\" in data:\n        values = data[\"values\"]\n        valid_values = torch.masked_select(values, response_mask)\n        return_diff_var = torch.var(valid_returns - valid_values)\n        return_var = torch.var(valid_returns)\n        \n    correct_threshold = 0.5\n    rewards_per_response = data[\"token_level_rewards\"].sum(-1)\n    correct_mask = rewards_per_response > correct_threshold\n    response_lengths = response_info[\"response_length\"]\n\n    # add by siirl\n    correct_response_length =  response_lengths[correct_mask]\n    wrong_response_length = response_lengths[~correct_mask]\n    \n    \n    metrics = {\n        # score\n        \"critic/score/mean\": torch.mean(sequence_score).detach().item(),\n        \"critic/score/max\": torch.max(sequence_score).detach().item(),\n        \"critic/score/min\": torch.min(sequence_score).detach().item(),\n        # reward\n        \"critic/rewards/mean\": torch.mean(sequence_reward).detach().item(),\n        \"critic/rewards/max\": torch.max(sequence_reward).detach().item(),\n        \"critic/rewards/min\": torch.min(sequence_reward).detach().item(),\n        # adv\n        \"critic/advantages/mean\": torch.mean(valid_adv).detach().item(),\n        \"critic/advantages/max\": torch.max(valid_adv).detach().item(),\n        \"critic/advantages/min\": torch.min(valid_adv).detach().item(),\n        # returns\n        \"critic/returns/mean\": torch.mean(valid_returns).detach().item(),\n        \"critic/returns/max\": torch.max(valid_returns).detach().item(),\n        \"critic/returns/min\": torch.min(valid_returns).detach().item(),\n        **(\n            {\n                # values\n                \"critic/values/mean\": torch.mean(valid_values).detach().item(),\n                \"critic/values/max\": torch.max(valid_values).detach().item(),\n                \"critic/values/min\": 
torch.min(valid_values).detach().item(),\n                # vf explained var\n                \"critic/vf_explained_var\": (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),\n            }\n            if \"values\" in data\n            else {}\n        ),\n        # response length\n        \"response/length/mean\": torch.mean(response_length).detach().item(),\n        \"response/length/max\": torch.max(response_length).detach().item(),\n        \"response/length/min\": torch.min(response_length).detach().item(),\n        \"response/clip_ratio/mean\": torch.mean(torch.eq(response_length, max_response_length).float()).detach().item(),\n        \"response/correct_length/max\": torch.max(correct_response_length).detach().item() if correct_response_length.numel() != 0 else 0,\n        \"response/correct_length/min\": torch.min(correct_response_length).detach().item() if correct_response_length.numel() != 0 else 0,\n        \"response/correct_length/mean\": torch.mean(correct_response_length).detach().item() if correct_response_length.numel() != 0 else 0,\n        \"response/wrong_length/max\": torch.max(wrong_response_length).detach().item() if wrong_response_length.numel() != 0 else 0,\n        \"response/wrong_length/min\": torch.min(wrong_response_length).detach().item() if wrong_response_length.numel() != 0 else 0,\n        \"response/wrong_length/mean\": torch.mean(wrong_response_length).detach().item() if wrong_response_length.numel() != 0 else 0,\n        \n        # prompt length\n        \"prompt/length/mean\": torch.mean(prompt_length).detach().item(),\n        \"prompt/length/max\": torch.max(prompt_length).detach().item(),\n        \"prompt/length/min\": torch.min(prompt_length).detach().item(),\n        \"prompt/clip_ratio/mean\": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),\n\n        # system_info\n        \"perf/process_cpu_mem_used_gb\" : psutil.Process(os.getpid()).memory_info().rss / (1024**3)\n    }\n    # multi-turn conversation\n    if \"__num_turns__\" in data:\n        num_turns = data[\"__num_turns__\"]\n        metrics[\"num_turns/min\"] = num_turns.min()\n        metrics[\"num_turns/max\"] = num_turns.max()\n        metrics[\"num_turns/mean\"] = num_turns.mean()\n    return metrics\n\n\ndef compute_timing_metrics(batch: TensorDict, timing_raw: Dict[str, float]) -> Dict[str, Any]:\n    \"\"\"\n    Computes timing metrics for different processing stages in PPO training.\n\n    This function calculates both raw timing metrics (in seconds) and per-token timing metrics\n    (in milliseconds) for various processing stages like generation, reference computation,\n    value computation, advantage computation, and model updates.\n\n    Args:\n        batch: A Tensordict object containing batch data with responses and attention masks.\n        timing_raw: A dictionary mapping stage names to their execution times in seconds.\n\n    Returns:\n        A dictionary containing:\n            - timing_s/{name}: Raw timing in seconds for each stage\n            - timing_per_token_ms/{name}: Per-token timing in milliseconds for each stage\n\n    Note:\n        Different stages use different token counts for normalization:\n        - \"gen\" uses only response tokens\n        - Other stages (\"ref\", \"values\", \"adv\", \"update_critic\", \"update_actor\") use all tokens\n          (prompt + response)\n    \"\"\"\n    response_info = _compute_response_info(batch)\n    num_prompt_tokens = 
torch.sum(response_info[\"prompt_length\"]).item()\n    num_response_tokens = torch.sum(response_info[\"response_length\"]).item()\n    num_overall_tokens = num_prompt_tokens + num_response_tokens\n\n    num_tokens_of_section = {\n        \"gen\": num_response_tokens,\n        **{name: num_overall_tokens for name in [\"ref\", \"values\", \"adv\", \"update_critic\", \"update_actor\"]},\n    }\n\n    return {\n        **{f\"timing_s/{name}\": value for name, value in timing_raw.items()},\n        **{f\"timing_per_token_ms/{name}\": timing_raw[name] * 1000 / num_tokens_of_section[name] for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys())},\n    }\n\n\ndef compute_throughout_metrics(batch: TensorDict, timing_raw: Dict[str, float], n_gpus: int) -> Dict[str, Any]:\n    \"\"\"\n    Computes throughput metrics for PPO training.\n\n    This function calculates performance metrics related to token processing speed,\n    including the total number of tokens processed, time per step, and throughput\n    (tokens per second per GPU).\n\n    Args:\n        batch: A TensorDict object containing batch data with meta information about token counts.\n        timing_raw: A dictionary mapping stage names to their execution times in seconds.\n                   Must contain a \"step\" key with the total step time.\n        n_gpus: Number of GPUs used for training.\n\n    Returns:\n        A dictionary containing:\n            - perf/total_num_tokens: Total number of tokens processed in the batch\n            - perf/time_per_step: Time taken for the step in seconds\n            - perf/throughput: Tokens processed per second per GPU\n\n    Note:\n        The throughput is calculated as total_tokens / (time * n_gpus) to normalize\n        across different GPU counts.\n    \"\"\"\n    total_num_tokens = sum(batch[\"global_token_num\"])\n    time = timing_raw[\"step\"]\n    # estimated_flops, promised_flops = flops_function.estimate_flops(num_tokens, time)\n    # f'Actual TFLOPs/s/GPU​': estimated_flops/(n_gpus),\n    # f'Theoretical TFLOPs/s/GPU​': promised_flops,\n    return {\n        \"perf/total_num_tokens\": total_num_tokens,\n        \"perf/time_per_step\": time,\n        \"perf/throughput\": total_num_tokens / (time * n_gpus),\n    }\n\n\ndef _calculate_bootstrap_metrics(group: pd.DataFrame, variable_name: str, subset_size: int, n_bootstrap: int = 1000) -> Dict[str, float]:\n    \"\"\"Performs fully vectorized bootstrap sampling to estimate statistics.\n\n    This is the core computational engine. 
It avoids all Python loops by using\n    NumPy's vectorized indexing and Scipy's vectorized mode calculation to\n    efficiently compute metrics for thousands of bootstrap samples at once.\n\n    Args:\n        group: DataFrame containing the data for a single prompt, including the\n               target variable column and potentially a 'pred' column.\n        variable_name: The name of the column to perform bootstrap sampling on.\n        subset_size: The size of each bootstrap sample (referred to as 'N').\n        n_bootstrap: The number of bootstrap iterations to perform.\n\n    Returns:\n        A dictionary containing the calculated mean and standard deviation for\n        best-of-N, worst-of-N, and majority-vote-of-N metrics.\n    \"\"\"\n    metrics = {}\n    variable_values = group[variable_name].to_numpy()\n\n    # --- Step 1: Generate all random indices for all bootstrap samples at once.\n    # This creates a 2D array of shape (n_bootstrap, subset_size), where each\n    # row is a set of indices for one bootstrap sample.\n    bootstrap_indices = np.random.choice(len(variable_values), size=(n_bootstrap, subset_size), replace=True)\n\n    # --- Step 2: Gather all bootstrap data samples using advanced indexing.\n    # This efficiently creates a 2D array of the actual data values for all samples.\n    bootstrap_data = variable_values[bootstrap_indices]\n\n    # --- Step 3: Vectorized calculation for best-of-N and worst-of-N.\n    # np.max/min along axis=1 finds the best/worst value within each sample.\n    # The result is a 1D array of shape (n_bootstrap,).\n    max_values_per_sample = np.max(bootstrap_data, axis=1)\n    min_values_per_sample = np.min(bootstrap_data, axis=1)\n\n    # Finally, calculate the mean and std across all bootstrap results.\n    metrics[f\"best@{subset_size}/mean\"] = np.mean(max_values_per_sample)\n    metrics[f\"best@{subset_size}/std\"] = np.std(max_values_per_sample)\n    metrics[f\"worst@{subset_size}/mean\"] = np.mean(min_values_per_sample)\n    metrics[f\"worst@{subset_size}/std\"] = np.std(min_values_per_sample)\n\n    # --- Step 4: Vectorized calculation for majority vote ('maj').\n    if \"pred\" in group.columns:\n        prediction_values = group[\"pred\"].to_numpy()\n        bootstrap_predictions = prediction_values[bootstrap_indices]\n\n        # Find the mode (most frequent prediction) for each bootstrap sample.\n        # `scipy.stats.mode` is vectorized and can operate along an axis.\n        modes_per_sample = mode(bootstrap_predictions, axis=1, keepdims=True)[0]\n\n        # To get the value associated with the majority vote, we find the *first*\n        # occurrence of the mode in each sample, replicating the original logic.\n        # `argmax` on the boolean mask provides the index of the first `True`.\n        mask = bootstrap_predictions == modes_per_sample\n        first_match_indices = np.argmax(mask, axis=1)\n\n        # Use the derived indices to gather the final majority vote values.\n        # This requires indexing the i-th row of `bootstrap_data` with the i-th index.\n        majority_values = bootstrap_data[np.arange(n_bootstrap), first_match_indices]\n\n        metrics[f\"maj@{subset_size}/mean\"] = np.mean(majority_values)\n        metrics[f\"maj@{subset_size}/std\"] = np.std(majority_values)\n\n    return metrics\n\n\n@ray.remote\ndef _process_prompt_group_task(group: pd.DataFrame, numeric_variables: List[str], seed: int) -> pd.DataFrame:\n    \"\"\"A Ray remote task to process metrics for a single prompt group.\n\n    This 
function serves as the parallel unit of work. It takes a DataFrame\n    for one prompt, calculates all standard and bootstrapped metrics, and\n    returns a tidy DataFrame of the results.\n\n    Args:\n        group: DataFrame containing all data for a single prompt.\n        numeric_variables: A list of column names to calculate metrics for.\n        seed: The random seed to ensure reproducible results for this task.\n\n    Returns:\n        A tidy DataFrame with columns ['data_source', 'prompt', 'var_name',\n        'metric_name', 'value'], containing all calculated metrics for the group.\n    \"\"\"\n    # Seed the random number generator for this specific worker.\n    np.random.seed(seed)\n\n    # Extract identifying information from the group.\n    data_source = group[\"data_source\"].iloc[0]\n    prompt = group[\"prompt\"].iloc[0]\n    num_responses = len(group)\n\n    # Store results in a list of dictionaries for efficient DataFrame creation.\n    results = []\n    for var_name in numeric_variables:\n        base_info = {\"data_source\": data_source, \"prompt\": prompt, \"var_name\": var_name}\n\n        # --- Calculate standard (non-bootstrapped) metrics ---\n        results.append({**base_info, \"metric_name\": f\"mean@{num_responses}\", \"value\": group[var_name].mean()})\n        \n        if num_responses > 1:\n            # 1. Re-added the original std@N metric for the user's logging block.\n            #    NOTE: Averaging this metric across prompts is statistically incorrect.\n            results.append({**base_info, \"metric_name\": f\"std@{num_responses}\", \"value\": group[var_name].std(ddof=1)})\n\n            # 2. Kept the components for the correct pooled standard deviation calculation.\n            #    These will be used for the function's actual return value.\n            variance = group[var_name].var(ddof=1)\n            df = num_responses - 1\n            sum_sq_dev = variance * df\n            results.append({**base_info, \"metric_name\": \"internal_sum_sq_dev_for_pooled_std\", \"value\": sum_sq_dev})\n            results.append({**base_info, \"metric_name\": \"internal_df_for_pooled_std\", \"value\": df})\n\n            # --- Calculate bootstrapped metrics for various sample sizes ---\n            bootstrap_sizes = sorted(list(set([2**i for i in range(1, 10) if 2**i < num_responses] + [num_responses])))\n\n            for size in bootstrap_sizes:\n                bootstrap_results = _calculate_bootstrap_metrics(group, var_name, subset_size=size)\n                for metric_name, value in bootstrap_results.items():\n                    results.append({**base_info, \"metric_name\": metric_name, \"value\": value})\n\n    return pd.DataFrame(results)\n\n\ndef bootstrap_metric(\n    data: list[Any],\n    subset_size: int,\n    reduce_fns: list[Callable[[np.ndarray], float]],\n    n_bootstrap: int = 1000,\n    seed: int = 42,\n) -> list[tuple[float, float]]:\n    \"\"\"\n    Performs bootstrap resampling to estimate statistics of metrics.\n\n    This function uses bootstrap resampling to estimate the mean and standard deviation\n    of metrics computed by the provided reduction functions on random subsets of the data.\n\n    Args:\n        data: List of data points to bootstrap from.\n        subset_size: Size of each bootstrap sample.\n        reduce_fns: List of functions that compute a metric from a subset of data.\n        n_bootstrap: Number of bootstrap iterations. Defaults to 1000.\n        seed: Random seed for reproducibility. 
Defaults to 42.\n\n    Returns:\n        A list of tuples, where each tuple contains (mean, std) for a metric\n        corresponding to each reduction function in reduce_fns.\n\n    Example:\n        >>> data = [1, 2, 3, 4, 5]\n        >>> reduce_fns = [np.mean, np.max]\n        >>> bootstrap_metric(data, 3, reduce_fns)\n        [(3.0, 0.5), (4.5, 0.3)]  # Example values\n    \"\"\"\n    np.random.seed(seed)\n\n    bootstrap_metric_lsts = [[] for _ in range(len(reduce_fns))]\n    for _ in range(n_bootstrap):\n        bootstrap_idxs = np.random.choice(len(data), size=subset_size, replace=True)\n        bootstrap_data = [data[i] for i in bootstrap_idxs]\n        for i, reduce_fn in enumerate(reduce_fns):\n            bootstrap_metric_lsts[i].append(reduce_fn(bootstrap_data))\n    return [(np.mean(lst), np.std(lst)) for lst in bootstrap_metric_lsts]\n\ndef calc_maj_val(data: list[dict[str, Any]], vote_key: str, val_key: str) -> float:\n    \"\"\"\n    Calculate a value based on majority voting.\n\n    This function identifies the most common value for a specified vote key\n    in the data, then returns the corresponding value for that majority vote.\n\n    Args:\n        data: List of dictionaries, where each dictionary contains both vote_key and val_key.\n        vote_key: The key in each dictionary used for voting/counting.\n        val_key: The key in each dictionary whose value will be returned for the majority vote.\n\n    Returns:\n        The value associated with the most common vote.\n\n    Example:\n        >>> data = [\n        ...     {\"pred\": \"A\", \"val\": 0.9},\n        ...     {\"pred\": \"B\", \"val\": 0.8},\n        ...     {\"pred\": \"A\", \"val\": 0.7}\n        ... ]\n        >>> calc_maj_val(data, vote_key=\"pred\", val_key=\"val\")\n        0.9  # Returns the first \"val\" for the majority vote \"A\"\n    \"\"\"\n    vote2vals = defaultdict(list)\n    for d in data:\n        vote2vals[d[vote_key]].append(d[val_key])\n\n    vote2cnt = {k: len(v) for k, v in vote2vals.items()}\n    maj_vote = max(vote2cnt, key=vote2cnt.get)\n\n    maj_val = vote2vals[maj_vote][0]\n\n    return maj_val\n\n\ndef process_validation_metrics(\n    data_sources: list[str], sample_inputs: list[str], infos_dict: dict[str, list[Any]], sample_turns: list[int], seed: int = 42\n) -> dict[str, dict[str, dict[str, float]]]:\n    \"\"\"\n    Process validation metrics into a structured format with statistical analysis.\n\n    This function organizes validation metrics by data source and prompt, then computes\n    various statistical measures including means, standard deviations, best/worst values,\n    and majority voting results. It also performs bootstrap sampling to estimate statistics\n    for different sample sizes.\n\n    Args:\n        data_sources: List of data source identifiers for each sample.\n        sample_inputs: List of input prompts corresponding to each sample.\n        infos_dict: Dictionary mapping variable names to lists of values for each sample.\n        seed: Random seed for bootstrap sampling. 
Defaults to 42.\n\n    Returns:\n        A nested dictionary with the structure:\n        {\n            data_source: {\n                variable_name: {\n                    metric_name: value\n                }\n            }\n        }\n\n        Where metric_name includes:\n        - \"mean@N\": Mean value across N samples\n        - \"std@N\": Standard deviation across N samples\n        - \"best@N/mean\": Mean of the best values in bootstrap samples of size N\n        - \"best@N/std\": Standard deviation of the best values in bootstrap samples\n        - \"worst@N/mean\": Mean of the worst values in bootstrap samples\n        - \"worst@N/std\": Standard deviation of the worst values in bootstrap samples\n        - \"maj@N/mean\": Mean of majority voting results in bootstrap samples (if \"pred\" exists)\n        - \"maj@N/std\": Standard deviation of majority voting results (if \"pred\" exists)\n\n    Example:\n        >>> data_sources = [\"source1\", \"source1\", \"source2\"]\n        >>> sample_inputs = [\"prompt1\", \"prompt1\", \"prompt2\"]\n        >>> infos_dict = {\"score\": [0.8, 0.9, 0.7], \"pred\": [\"A\", \"A\", \"B\"]}\n        >>> result = process_validation_metrics(data_sources, sample_inputs, infos_dict)\n        >>> # result will contain statistics for each data source and variable\n    \"\"\"\n    # Group metrics by data source, prompt and variable\n    data_src2prompt2var2vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))\n    for sample_idx, data_source in enumerate(data_sources):\n        prompt = sample_inputs[sample_idx]\n        var2vals = data_src2prompt2var2vals[data_source][prompt]\n        for var_name, var_vals in infos_dict.items():\n            var2vals[var_name].append(var_vals[sample_idx])\n\n    # Calculate metrics for each group\n    data_src2prompt2var2metric = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))\n    for data_source, prompt2var2vals in data_src2prompt2var2vals.items():\n        for prompt, var2vals in prompt2var2vals.items():\n            for var_name, var_vals in var2vals.items():\n                if isinstance(var_vals[0], str):\n                    continue\n\n                metric = {}\n                n_resps = len(var_vals)\n                metric[f\"mean@{n_resps}\"] = np.mean(var_vals)\n\n                if n_resps > 1:\n                    metric[f\"std@{n_resps}\"] = np.std(var_vals)\n\n                    ns = []\n                    n = 2\n                    while n < n_resps:\n                        ns.append(n)\n                        n *= 2\n                    ns.append(n_resps)\n\n                    for n in ns:\n                        [(bon_mean, bon_std), (won_mean, won_std)] = bootstrap_metric(\n                            data=var_vals, subset_size=n, reduce_fns=[np.max, np.min], seed=seed\n                        )\n                        metric[f\"best@{n}/mean\"], metric[f\"best@{n}/std\"] = bon_mean, bon_std\n                        metric[f\"worst@{n}/mean\"], metric[f\"worst@{n}/std\"] = won_mean, won_std\n                        if var2vals.get(\"pred\", None) is not None:\n                            vote_data = [{\"val\": val, \"pred\": pred} for val, pred in zip(var_vals, var2vals[\"pred\"])]\n                            [(maj_n_mean, maj_n_std)] = bootstrap_metric(\n                                data=vote_data,\n                                subset_size=n,\n                                reduce_fns=[partial(calc_maj_val, vote_key=\"pred\", val_key=\"val\")],\n        
                        seed=seed,\n                            )\n                            metric[f\"maj@{n}/mean\"], metric[f\"maj@{n}/std\"] = maj_n_mean, maj_n_std\n\n                data_src2prompt2var2metric[data_source][prompt][var_name] = metric\n\n    # Aggregate metrics across prompts\n    data_src2var2metric2prompt_vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))\n    for data_source, prompt2var2metric in data_src2prompt2var2metric.items():\n        for prompt, var2metric in prompt2var2metric.items():\n            for var_name, metric in var2metric.items():\n                for metric_name, metric_val in metric.items():\n                    data_src2var2metric2prompt_vals[data_source][var_name][metric_name].append(metric_val)\n\n    data_src2var2metric2val = defaultdict(lambda: defaultdict(lambda: defaultdict(float)))\n    for data_source, var2metric2prompt_vals in data_src2var2metric2prompt_vals.items():\n        for var_name, metric2prompt_vals in var2metric2prompt_vals.items():\n            for metric_name, prompt_vals in metric2prompt_vals.items():\n                data_src2var2metric2val[data_source][var_name][metric_name] = np.mean(prompt_vals)\n    metric_dict = {}\n    for data_source, var2metric2val in data_src2var2metric2val.items():\n        core_var = \"acc\" if \"acc\" in var2metric2val else \"reward\"\n        for var_name, metric2val in var2metric2val.items():\n            n_max = max([int(name.split(\"@\")[-1].split(\"/\")[0]) for name in metric2val.keys()])\n            for metric_name, metric_val in metric2val.items():\n                if (\n                    (var_name == core_var)\n                    and any(metric_name.startswith(pfx) for pfx in [\"mean\", \"maj\", \"best\"])\n                    and (f\"@{n_max}\" in metric_name)\n                ):\n                    metric_sec = \"val-core\"\n                else:\n                    metric_sec = \"val-aux\"\n                pfx = f\"{metric_sec}/{data_source}/{var_name}/{metric_name}\"\n                metric_dict[pfx] = metric_val\n\n    if len(sample_turns) > 0:\n        sample_turns = np.concatenate(sample_turns)\n        metric_dict[\"val-aux/num_turns/min\"] = sample_turns.min()\n        metric_dict[\"val-aux/num_turns/max\"] = sample_turns.max()\n        metric_dict[\"val-aux/num_turns/mean\"] = sample_turns.mean()\n    \n    \n    # calculate test_score per data source\n    data_source_rewards = defaultdict(list)\n    for data_source, reward in zip(data_sources, infos_dict['reward']):\n        data_source_rewards[data_source].append(reward)\n\n    for source, rewards in data_source_rewards.items():\n        if rewards:\n            metric_dict[f\"val/test_score/{source}\"] = np.mean(rewards)\n    return metric_dict\n\n\ndef aggregate_validation_metrics(data_sources: List[str], sample_inputs: List[str], infos_dict: Dict[str, List[Any]], seed: int = 42) -> Dict[str, Dict[str, Dict[str, float]]]:\n    \"\"\"Process validation metrics into a structured format with statistical analysis.\n\n    This function organizes validation metrics by data source and prompt, then computes\n    various statistical measures including means, standard deviations, best/worst values,\n    and majority voting results. 
It also performs bootstrap sampling to estimate statistics\n    for different sample sizes.\n\n    Args:\n        data_sources: List of data source identifiers for each sample.\n        sample_inputs: List of input prompts corresponding to each sample.\n        infos_dict: Dictionary mapping variable names to lists of values for each sample.\n        seed: Random seed for bootstrap sampling. Defaults to 42.\n\n    Returns:\n        A nested dictionary with the structure:\n        {\n            data_source: {\n                variable_name: {\n                    metric_name: value\n                }\n            }\n        }\n\n        Where metric_name includes:\n        - \"mean@N\": Mean value across N samples\n        - \"std@N\": Standard deviation across N samples\n        - \"best@N/mean\": Mean of the best values in bootstrap samples of size N\n        - \"best@N/std\": Standard deviation of the best values in bootstrap samples\n        - \"worst@N/mean\": Mean of the worst values in bootstrap samples\n        - \"worst@N/std\": Standard deviation of the worst values in bootstrap samples\n        - \"maj@N/mean\": Mean of majority voting results in bootstrap samples (if \"pred\" exists)\n        - \"maj@N/std\": Standard deviation of majority voting results (if \"pred\" exists)\n\n    Example:\n        >>> data_sources = [\"source1\", \"source1\", \"source2\"]\n        >>> sample_inputs = [\"prompt1\", \"prompt1\", \"prompt2\"]\n        >>> infos_dict = {\"score\": [0.8, 0.9, 0.7], \"pred\": [\"A\", \"A\", \"B\"]}\n        >>> result = aggregate_validation_metrics(data_sources, sample_inputs, infos_dict)\n        >>> # result will contain statistics for each data source and variable\n    \"\"\"\n    # --- 1. Data Consolidation ---\n    # Combine all input lists into a single, unified DataFrame.\n    df = pd.DataFrame({\"data_source\": data_sources, \"prompt\": sample_inputs, **infos_dict})\n    numeric_vars = [col for col, dtype in df.dtypes.items() if pd.api.types.is_numeric_dtype(dtype)]\n\n    # --- 2. Task Preparation ---\n    # Split the DataFrame into a list of smaller DataFrames, one for each prompt group.\n    prompt_groups = [group for _, group in df.groupby([\"data_source\", \"prompt\"])]\n\n    # --- 3. Parallel Dispatch ---\n    # Launch all processing tasks concurrently. `ray.remote` returns immediately\n    # with a future (ObjectRef) for each task.\n    futures = [_process_prompt_group_task.remote(group, numeric_vars, int(seed)) for group in prompt_groups]\n\n    # --- 4. Result Collection ---\n    # `ray.get` blocks until all tasks are complete and retrieves their results.\n    processed_df_list = ray.get(futures)\n    \n    if not processed_df_list:\n        return {}\n    processed_df = pd.concat(processed_df_list)    \n\n    # --- 6. 
Final Aggregation ---\n    # Perform a single, efficient groupby to get the mean value of each metric\n    # across all prompts within a data source.\n    # Separate the standard metrics from the internal components for pooled std.\n    is_std_component = processed_df[\"metric_name\"].str.startswith(\"internal_\")\n    is_legacy_std = processed_df[\"metric_name\"].str.startswith(\"std@\")\n    \n    # Exclude internal components AND the legacy std@N metric from the regular aggregation.\n    regular_metrics_df = processed_df[~is_std_component & ~is_legacy_std]\n    std_components_df = processed_df[is_std_component]\n\n    # Aggregate regular metrics by taking the mean across all prompts.\n    final_agg_df = regular_metrics_df.groupby([\"data_source\", \"var_name\", \"metric_name\"])[\"value\"].mean().reset_index()\n\n    final_df = final_agg_df\n    # Calculate the pooled standard deviation correctly.\n    if not std_components_df.empty:\n        summed_components = std_components_df.groupby([\"data_source\", \"var_name\", \"metric_name\"])[\"value\"].sum().unstack()\n        total_df = summed_components[\"internal_df_for_pooled_std\"]\n        pooled_variance = summed_components[\"internal_sum_sq_dev_for_pooled_std\"].divide(total_df).fillna(0)\n        pooled_std = np.sqrt(pooled_variance)\n        \n        pooled_std_df = pooled_std.reset_index(name=\"value\")\n        pooled_std_df[\"metric_name\"] = \"pooled_std\"\n        \n        final_df = pd.concat([final_agg_df, pooled_std_df], ignore_index=True)\n\n    # --- 7. Output Formatting ---\n    # Convert the flattened Series from the groupby into the required nested dict.\n    output_dict = defaultdict(lambda: defaultdict(dict))\n    for _, row in final_df.iterrows():\n        output_dict[row['data_source']][row['var_name']][row['metric_name']] = row['value']\n\n    return output_dict\n"
  },
  {
    "path": "siirl/utils/model_utils/__init__.py",
    "content": "# Copyright (c) 2025, Shanghai Innovation Institute. All rights reserved.\n"
  },
  {
    "path": "siirl/utils/model_utils/activation_offload.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Functionality for CPU offloading of tensors saved for backward pass.\"\"\"\n\nfrom __future__ import annotations\n\nimport functools\nfrom typing import Any, Optional\n\nimport torch\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\n\nfrom siirl.utils.extras.device import get_torch_device\nfrom siirl.utils.model_utils.fsdp_utils import FSDPModule as FSDP2\n\nfrom loguru import logger\n\n\ndef _get_unique_tensor_key(tensor):\n    key = (tensor.untyped_storage().data_ptr() + tensor.storage_offset(), tensor.dtype)\n    return key\n\n\nclass FSDPParameterFilter:\n    def __init__(self):\n        self.model_parameters_storage = set()\n\n    def __call__(self, tensor):\n        return tensor.untyped_storage().data_ptr() not in self.model_parameters_storage\n\n    def update_model_parameters(self, model):\n        new_storage = set()\n        for p in model.parameters():\n            new_storage.add(p.data.untyped_storage().data_ptr())\n        self.model_parameters_storage = new_storage\n\n\nclass CpuOffloadHookWithOffloadHandler:\n    \"\"\"Context-manager that offloads/recovers tensors through an offload hander.\n\n    The hook just offloads/recovers the tensor object to the handler through `tensor_push`\n    and `tensor_pop` interface. 
How the offload-handler manages the offloading, recovering\n    or prefetching timing is transparent to this hook.\n    \"\"\"\n\n    def __init__(\n        self,\n        offload_handler: OffloadHandler,\n        handler_extra_kwargs: Optional[dict[str, Any]] = None,\n    ) -> None:\n        if handler_extra_kwargs is None:\n            handler_extra_kwargs = {}\n        self.offload_handler: OffloadHandler = offload_handler\n        self.handler_extra_kwargs: dict[str, Any] = handler_extra_kwargs\n        self.inside_context = False\n\n    def __enter__(self):\n        self.inside_context = True\n        torch._C._autograd._push_saved_tensors_default_hooks(self.on_save_for_backward, self.on_get_saved_tensor)\n\n    def __exit__(self, *args: Any):\n        self.inside_context = False\n        torch._C._autograd._pop_saved_tensors_default_hooks()\n\n    def on_save_for_backward(self, tensor: torch.Tensor) -> Any:\n        retrieve_identifier = self.offload_handler.tensor_push(tensor, **self.handler_extra_kwargs)\n        return retrieve_identifier\n\n    def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor:\n        tensor = self.offload_handler.tensor_pop(saved_state, **self.handler_extra_kwargs)\n        return tensor\n\n\nclass OffloadHandler:\n    \"\"\"A base class for CPU offload-handler.\"\"\"\n\n    def __init__(self) -> None:\n        pass\n\n    def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any:\n        \"\"\"Tensor push.\"\"\"\n        raise NotImplementedError(\"`tensor_push` is not implemented in OffloadHandler class. Inherit this class and implement your custom tensor_push.\")\n\n    def tensor_pop(self, tensor_tag: Any, **kwargs):\n        \"\"\"Tensor pop.\"\"\"\n        raise NotImplementedError(\"`tensor_pop` is not implemented in OffloadHandler class. Inherit this class and implement your custom tensor_pop.\")\n\n\nclass GroupCommitFunction(torch.autograd.Function):\n    \"\"\"This is a dummy op with output identical to input.\n    However, it is necessary for marking a timepoint for offload handler to\n    accomplish all synchronizations. 
Implementing it as a function is necessary\n    because we need to actions in both forward and backward.\n    \"\"\"\n\n    @staticmethod\n    def forward(ctx, tensor, cpu_offload_handler):\n        # pylint: disable=missing-function-docstring\n        cpu_offload_handler.on_group_commit_forward()\n        ctx.cpu_offload_handler = cpu_offload_handler\n        # return the identical tensor\n        return tensor\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        # pylint: disable=missing-function-docstring\n        cpu_offload_handler = ctx.cpu_offload_handler\n        cpu_offload_handler.on_group_commit_backward()\n        return grad_output, None\n\n\ngroup_prefetch_offload_commit = GroupCommitFunction.apply\n\n\nclass SynchronizedGroupOffloadHandler(OffloadHandler):\n    \"\"\"Offload Handler that offloads/reloads in a synchronized way.\n    The device-to-host and host-to-device copying happen in the same stream\n    as the computation kernels, thus the copying will block computation.\n    \"\"\"\n\n    def __init__(self, num_offload_group, tensor_need_offloading_checker=(lambda _: True)) -> None:\n        super().__init__()\n\n        self.num_offload_group = num_offload_group\n        self.tensor_need_offloading_checker = tensor_need_offloading_checker\n\n        self.groupid_reset()\n\n    def groupid_reset(self):\n        \"\"\"Groupid reset.\"\"\"\n        # Data structures to label saved tensors and book-keep their cpu copies.\n        # Currently, on push, create a new cpu tensor and copies; on pop, copies\n        # the tensor back to gpu and deletes the cpu tensor.\n        # These will increment whenever `group_commit()` is invoked\n        self.current_group, self.tensor_count_current_group = (0, 0)\n        self.torch_tensor_count = 0\n        self.tensor_tag_to_state = {}\n\n    def on_group_commit_forward(self):\n        \"\"\"On group commit forward.\"\"\"\n        # finishing up with updating current group and tensor count\n        self.current_group += 1  # increment\n        self.tensor_count_current_group = 0  # reset\n\n    def on_group_commit_backward(self):\n        \"\"\"On group commit backward.\"\"\"\n        self.current_group -= 1\n        assert self.current_group >= 0\n\n    @staticmethod\n    def offload(src_tensor, pin_memory=True):\n        \"\"\"Offload.\"\"\"\n\n        cpu_backup = torch.empty(\n            src_tensor.size(),\n            dtype=src_tensor.dtype,\n            layout=src_tensor.layout,\n            device=\"cpu\",\n            pin_memory=pin_memory,\n        )\n        cpu_backup.copy_(src_tensor, non_blocking=True)\n        state = (src_tensor.device, cpu_backup)\n        return state\n\n    @staticmethod\n    def reload(state, non_blocking=None):\n        \"\"\"Reload.\"\"\"\n        dev, cpu_backup = state\n        if non_blocking is None:\n            non_blocking = cpu_backup.is_pinned()\n        return cpu_backup.to(dev, non_blocking=non_blocking)\n\n    def tensor_push(self, tensor: torch.Tensor, **kwargs):\n        \"\"\"Tensor push.\"\"\"\n        # obtain a unique tensor tag\n        tensor_tag = (self.current_group, self.tensor_count_current_group)\n        self.tensor_count_current_group += 1\n        assert tensor_tag not in self.tensor_tag_to_state\n        if self.current_group < self.num_offload_group and self.tensor_need_offloading_checker(tensor):\n            state = SynchronizedGroupOffloadHandler.offload(tensor)\n            self.tensor_tag_to_state[tensor_tag] = state\n        else:\n            # 
will be offloaded together after group commit\n            self.tensor_tag_to_state[tensor_tag] = tensor\n\n        return tensor_tag\n\n    def tensor_pop(self, tensor_tag, **kwargs):\n        \"\"\"Tensor pop.\"\"\"\n        assert tensor_tag in self.tensor_tag_to_state\n        state = self.tensor_tag_to_state.pop(tensor_tag)\n        if isinstance(state, tuple):\n            tensor = SynchronizedGroupOffloadHandler.reload(state)\n        else:\n            tensor = state\n        return tensor\n\n\nclass AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):\n    \"\"\"Compared to synchronize, this uses more memory because of the buffer but\n    achieves better performance due to the overlapping. D2h and h2d copying are\n    completely hidden behind computation if computation time of a layer is longer\n    than host-device communication time. Bulk offloading with delay and bulk reloading\n    with prefetch are implemented.\"\"\"\n\n    def __init__(\n        self,\n        num_offload_group,  # must be <= actual number of groups (number of commits)\n        num_model_group,\n        tensor_need_offloading_checker=(lambda t: True),\n    ) -> None:\n        super().__init__(\n            num_offload_group=num_offload_group,\n            tensor_need_offloading_checker=tensor_need_offloading_checker,\n        )\n        # Number of layers in the model\n        self.num_layers = num_model_group\n        # Data Structure to maintain reference to activation tensors\n        self.tensor_tag_to_buf = {}\n        # Tracking the number of layers offloaded\n        self.offloaded_group_count = 0\n        # Core data structure that decides the window for offloading\n        self.layer_window_map = {}\n        self.group_offload_mapping = {}\n\n        # Logic to make offloading load balance across computation\n        # for optimal CPU/GPU interconnect usage\n        constant = 0\n        for i in range(self.num_offload_group):\n            self.layer_window_map[i] = ((self.num_layers // self.num_offload_group) * (i + 1)) - 1\n            if i < (self.num_layers % self.num_offload_group):\n                self.layer_window_map[i] += i + 1\n                constant = i + 1\n            else:\n                self.layer_window_map[i] += constant\n\n        # allocate streams and events for synchronization\n        self.d2h_stream = get_torch_device().Stream()\n        self.h2d_stream = get_torch_device().Stream()\n\n    def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any:\n        torch_stray_tensor = isinstance(\n            tensor,\n            (\n                torch._subclasses.fake_tensor.FakeTensor,\n                torch._subclasses.functional_tensor.FunctionalTensor,\n            ),\n        )\n        need_offload = not torch_stray_tensor\n        need_offload = need_offload and self.tensor_need_offloading_checker(tensor)\n\n        if need_offload:\n            # obtain a unique tensor tag\n            tensor_tag = (self.current_group, self.tensor_count_current_group)\n            self.tensor_count_current_group += 1\n\n            assert tensor_tag not in self.tensor_tag_to_state\n            self.tensor_tag_to_state[tensor_tag] = tensor\n\n            if self.current_group < self.num_offload_group:\n                self.tensor_tag_to_buf[tensor_tag] = tensor\n        else:\n            tensor_tag = tensor\n        return tensor_tag\n\n    def tensor_pop(self, tensor_tag, **kwargs):\n        \"\"\"Tensor pop.\"\"\"\n        if isinstance(tensor_tag, 
torch.Tensor):\n            return tensor_tag\n        assert tensor_tag in self.tensor_tag_to_state\n        tensor = self.tensor_tag_to_state.pop(tensor_tag)\n        self.tensor_tag_to_buf.pop(tensor_tag, None)\n\n        # the tensor should have been copied back in on_group_commit_backward()\n        # which invokes bulk_reload_group.\n        assert not isinstance(tensor, tuple)\n        return tensor\n\n    def bulk_offload_group(self, group_to_offload):\n        \"\"\"Bulk offload group.\"\"\"\n        offload_mapping = {}\n        offload_size = 0\n        with get_torch_device().stream(self.d2h_stream):\n            for tensor_tag, state in self.tensor_tag_to_state.items():\n                group_id, _ = tensor_tag\n                if group_id == group_to_offload:\n                    assert not isinstance(state, tuple)\n                    key = _get_unique_tensor_key(state)\n                    if key not in offload_mapping:\n                        offload_mapping[key] = state\n                    # if offload, return the reference to cpu copy\n                    self.tensor_tag_to_state[tensor_tag] = (key, state.shape)\n            for key, tensor in offload_mapping.items():\n                state = SynchronizedGroupOffloadHandler.offload(tensor)\n                offload_size += tensor.numel() * tensor.element_size()\n                offload_mapping[key] = state\n\n            self.group_offload_mapping[group_to_offload] = offload_mapping\n\n    def synchronize_on_group_commit_forward(self, current_group):\n        \"\"\"Synchronize on group commit forward.\"\"\"\n\n        # For the first group, kickstart the offload after we have\n        # the first compute completion\n        if current_group == 0:\n            self.d2h_stream.wait_stream(get_torch_device().current_stream())\n            self.bulk_offload_group(current_group)\n\n        # Window map data structure helps us synchronize based on number\n        # of layers offloaded\n        if self.layer_window_map[self.offloaded_group_count] == current_group:\n            # Stream synchronization both ways\n            self.d2h_stream.wait_stream(get_torch_device().current_stream())\n            get_torch_device().current_stream().wait_stream(self.d2h_stream)\n\n            # Time to free the activation memory after usage\n            for tensor_tag, _ in self.tensor_tag_to_buf.items():\n                if tensor_tag[0] == self.offloaded_group_count:\n                    self.tensor_tag_to_buf[tensor_tag] = None\n\n            # Time to offload the next group\n            if self.offloaded_group_count < (self.num_offload_group - 1):\n                self.bulk_offload_group(self.offloaded_group_count + 1)\n\n            # Increment the offload group count to keep track\n            self.offloaded_group_count += 1\n\n    def on_group_commit_forward(self):\n        \"\"\"This function will cause host device synchronization\"\"\"\n        # handle synchronization events\n        self.synchronize_on_group_commit_forward(self.current_group)\n\n        super().on_group_commit_forward()\n\n    @torch.no_grad\n    def bulk_reload_group(self, group_to_reload):\n        \"\"\"Bulk reload group.\"\"\"\n        assert group_to_reload < self.num_offload_group\n\n        with get_torch_device().stream(self.h2d_stream):\n            # move back tensors\n            offload_mapping = self.group_offload_mapping.pop(group_to_reload)\n            assert offload_mapping is not None\n            for key, state in offload_mapping.items():\n     
           offload_mapping[key] = SynchronizedGroupOffloadHandler.reload(state)\n            for tensor_label, state in self.tensor_tag_to_state.items():\n                group_id, _ = tensor_label\n                if group_id == group_to_reload and not isinstance(state, torch.Tensor):\n                    assert isinstance(state, tuple), f\"{group_id} {state}\"\n                    key, shape = state\n                    recovered_tensor = offload_mapping[key].view(shape)\n                    self.tensor_tag_to_state[tensor_label] = recovered_tensor\n\n    def on_group_commit_backward(self):\n        # first decrement the current group.\n        # after last commit in forward, the group will +1; in backward it -1.\n        # Finally it should be decremented to 0.\n        self.current_group -= 1\n        assert self.current_group >= 0\n\n        # Layer window data structure helps us to reload at right times\n        if self.layer_window_map[self.offloaded_group_count - 1] == self.current_group:\n            # Stream synchronization both ways\n            self.h2d_stream.wait_stream(get_torch_device().current_stream())\n            get_torch_device().current_stream().wait_stream(self.h2d_stream)\n\n            # Time to reload the next group\n            self.bulk_reload_group(self.offloaded_group_count - 1)\n\n            # Decrease the offloading group counter\n            self.offloaded_group_count -= 1 if self.offloaded_group_count > 1 else 0\n\n        # Last group computation needs to wait till all the reloads complete\n        if self.current_group == 0:\n            get_torch_device().current_stream().wait_stream(self.h2d_stream)\n            self.offloaded_group_count = 0\n\n\ndef get_activation_offload_context(num_layers: int = 1, model_layers: int = 1, tensor_need_offloading_checker=(lambda t: True)):\n    cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler(\n        num_offload_group=num_layers,\n        num_model_group=model_layers,\n        tensor_need_offloading_checker=tensor_need_offloading_checker,\n    )\n\n    def group_prefetch_offload_commit_async(tensor):\n        return group_prefetch_offload_commit(tensor, cpu_offload_handler)\n\n    return (\n        CpuOffloadHookWithOffloadHandler(offload_handler=cpu_offload_handler),\n        group_prefetch_offload_commit_async,\n    )\n\n\nclass ActivationHandler:\n    def __init__(self, offload_ctx, sync_func, tensor_filter, enable_ckpt):\n        self._offload_ctx = offload_ctx\n        self._sync_func = sync_func\n        self._enable_ckpt = enable_ckpt\n        self._tensor_filter = tensor_filter\n        if enable_ckpt:\n            self.checkpoint_fn = functools.partial(\n                torch.utils.checkpoint.checkpoint,\n                use_reentrant=True,\n            )\n\n    def pre_forward(self, module):\n        if module.training:\n            self._offload_ctx.__enter__()\n            self._tensor_filter.update_model_parameters(module)\n\n    def post_forward(self, module):\n        if module.training:\n            self._offload_ctx.__exit__(None, None, None)\n\n    def _pack_kwargs(self, *args, **kwargs):\n        kwarg_keys = []\n        flat_args = list(args)\n        for k, v in kwargs.items():\n            kwarg_keys.append(k)\n            flat_args.append(v)\n\n        return tuple(flat_args), tuple(kwarg_keys)\n\n    def _unpack_kwargs(self, flat_args, kwarg_keys):\n        assert len(kwarg_keys) <= len(flat_args), f\"too many keys {len(kwarg_keys)} vs. 
{len(flat_args)}\"\n        if len(kwarg_keys) == 0:\n            return flat_args, {}\n        args = flat_args[: -len(kwarg_keys)]\n        kwargs = dict(zip(kwarg_keys, flat_args[-len(kwarg_keys) :]))\n        return args, kwargs\n\n    def _ckpt_forward(self, forward_method, *args, **kwargs):\n        flat_args, kwarg_keys = self._pack_kwargs(*args, **kwargs)\n\n        def my_function(*inputs):\n            # unpack back into args and kwargs\n            nonlocal forward_method, kwarg_keys\n            unpacked_args, unpacked_kwargs = self._unpack_kwargs(inputs, kwarg_keys)\n            # run original module\n            return forward_method(*unpacked_args, **unpacked_kwargs)\n\n        return self.checkpoint_fn(\n            my_function,\n            *flat_args,\n        )\n\n    def forward(self, module, forward_method, *args, **kwargs):\n        if not module.training:\n            return forward_method(*args, **kwargs)\n        if not self._enable_ckpt:\n            ret = forward_method(*args, **kwargs)\n        else:\n            ret = self._ckpt_forward(forward_method, *args, **kwargs)\n        binded_tensor = ret\n        if isinstance(ret, tuple):\n            binded_tensor = ret[0]\n        binded_tensor = self._sync_func(binded_tensor)\n        final_ret = binded_tensor\n        if isinstance(ret, tuple):\n            final_ret = (final_ret,) + ret[1:]\n        return final_ret\n\n    def wrap_module_forward_method(self, module):\n        orig_method = module.forward\n        handler = self\n\n        @functools.wraps(orig_method)\n        def wrapped_method(model_self, *args, **kwargs):\n            nonlocal handler\n            handler.pre_forward(model_self)\n            out = handler.forward(model_self, orig_method, *args, **kwargs)\n            handler.post_forward(model_self)\n            return out\n\n        module.forward = wrapped_method.__get__(module, type(module))\n\n\ndef enable_activation_offloading(model, strategy, enable_ckpt=False):\n    \"\"\"\n    Enable activation offloading for the model. It groups activations by TransformerLayer and offloads activation\n    groups asynchronously. This means that the offloading of the i-th activation group and the computation of the i+1-th\n    activation group happen at the same time, and there are at most two activation groups in GPU memory.\n\n    Args:\n        model: the model to enable activation offloading\n        strategy: the training strategy of the model, such as \"fsdp\"\n        enable_ckpt: whether activation checkpointing(also called gradient checkpointing) has been enabled for the model\n\n    Note:\n        For best efficiency, activation offloading is usually combined with activation checkpointing. However, this\n        implementation of activation offloading is conflicted with the implementation of activation checkpointing in\n        some training strategies. 
This function resolves this conflict, and therefore requires the \"strategy\" and\n    \"enable_ckpt\" arguments.\n\n    Returns:\n        None. The offloading hooks are installed on the model in place.\n    \"\"\"\n\n    assert strategy == \"fsdp\" or strategy == \"fsdp2\", \"activation offloading only supports fsdp strategy\"\n    layers = []\n\n    def get_layers(module):\n        for name, child in module.named_children():\n            if not isinstance(child, (FSDP, FSDP2)):\n                get_layers(child)\n            else:\n                wrapped_module = child\n                if isinstance(child, FSDP):\n                    wrapped_module = child._fsdp_wrapped_module\n                # In some cases, torch.nn.Embedding is wrapped with FSDP alone. However, the activation\n                # size of torch.nn.Embedding is small, so it's not necessary to offload it.\n                if not isinstance(wrapped_module, torch.nn.Embedding):\n                    layers.append(child)\n\n    get_layers(model)\n    if len(layers) < 3:\n        logger.warning(f\"Found only {len(layers)} FSDP layers, no need to enable async activation offloading\")\n        return\n\n    tensor_filter = FSDPParameterFilter()\n    context, sync_func = get_activation_offload_context(len(layers) - 1, len(layers), tensor_filter)\n    if enable_ckpt:\n        # The implementation of activation checkpointing in the transformers library is incompatible with activation offloading,\n        # so it will be disabled here; this implementation supports another version of activation checkpointing, so that\n        # these two features can be enabled at the same time.\n        for module in model.modules():\n            if hasattr(module, \"gradient_checkpointing_disable\"):\n                module.gradient_checkpointing_disable()\n\n    handler = ActivationHandler(context, sync_func, tensor_filter, enable_ckpt)\n    for layer in layers:\n        module = layer\n        if isinstance(layer, FSDP):\n            module = module._fsdp_wrapped_module\n        handler.wrap_module_forward_method(module)\n"
  },
  {
    "path": "siirl/utils/model_utils/attention_utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Callable\n\n_index_first_axis, _pad_input, _rearrange, _unpad_input = None, None, None, None\n\n\ndef _get_attention_functions() -> tuple[Callable, Callable, Callable, Callable]:\n    \"\"\"Dynamically import attention functions based on available hardware.\"\"\"\n\n    from siirl.utils.extras.device import is_cuda_available, is_npu_available\n\n    global _index_first_axis, _pad_input, _rearrange, _unpad_input\n\n    if is_cuda_available:\n        from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input\n    elif is_npu_available:\n        from siirl.utils.model_utils.npu_utils import index_first_axis, pad_input, rearrange, unpad_input\n\n    _index_first_axis, _pad_input, _rearrange, _unpad_input = index_first_axis, pad_input, rearrange, unpad_input\n\n    return _index_first_axis, _pad_input, _rearrange, _unpad_input\n\n\ndef index_first_axis(*args, **kwargs):\n    \"\"\"\n    Unified entry point for `index_first_axis` across CUDA and NPU backends.\n\n    Dynamically dispatches to the appropriate device-specific implementation:\n      - On CUDA: `flash_attn.bert_padding.index_first_axis`\n      - On NPU: `transformers.integrations.npu_flash_attention.index_first_axis`\n        (falls back to `transformers.modeling_flash_attention_utils._index_first_axis`\n        in newer versions of transformers).\n\n    Users can call this function directly without worrying about the underlying device.\n    \"\"\"\n    func, *_ = _get_attention_functions()\n    return func(*args, **kwargs)\n\n\ndef pad_input(*args, **kwargs):\n    \"\"\"\n    Unified entry point for `pad_input` across CUDA and NPU backends.\n\n    Dynamically dispatches to the appropriate device-specific implementation:\n      - On CUDA: `flash_attn.bert_padding.pad_input`\n      - On NPU: `transformers.integrations.npu_flash_attention.pad_input`\n        (falls back to `transformers.modeling_flash_attention_utils._pad_input`\n        in newer versions of transformers).\n\n    Users can call this function directly without worrying about the underlying device.\n    \"\"\"\n    _, func, *_ = _get_attention_functions()\n    return func(*args, **kwargs)\n\n\ndef rearrange(*args, **kwargs):\n    \"\"\"\n    Unified entry point for `rearrange` across CUDA and NPU backends.\n\n    Dynamically dispatches to the appropriate device-specific implementation:\n      - On CUDA: `flash_attn.bert_padding.rearrange`\n      - On NPU: `transformers.integrations.npu_flash_attention.rearrange`\n        (falls back to `einops.rearrange` if no dedicated NPU implementation exists).\n\n    Users can call this function directly without worrying about the underlying device.\n    \"\"\"\n    *_, func, _ = _get_attention_functions()\n    return func(*args, **kwargs)\n\n\ndef unpad_input(*args, **kwargs):\n    \"\"\"\n    Unified entry point for `unpad_input` across CUDA and NPU backends.\n\n 
   Dynamically dispatches to the appropriate device-specific implementation:\n      - On CUDA: `flash_attn.bert_padding.unpad_input`\n      - On NPU: `transformers.integrations.npu_flash_attention.unpad_input`\n        (falls back to `transformers.modeling_flash_attention_utils._unpad_input`\n        in newer versions of transformers).\n\n    Users can call this function directly without worrying about the underlying device.\n    \"\"\"\n    *_, func = _get_attention_functions()\n    return func(*args, **kwargs)\n\n\n__all__ = [\"index_first_axis\", \"pad_input\", \"rearrange\", \"unpad_input\"]\n"
  },
  {
    "path": "siirl/utils/model_utils/flops_counter.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom transformers import PretrainedConfig\n\nfrom siirl.utils.extras.device import get_torch_device\n\nVALID_CONFIG_TYPE = {\n    \"llama\",\n    \"qwen2\",\n    \"qwen2_vl\",\n    \"qwen2_5_vl\",\n    \"internvl_chat\",\n    \"qwen3\",\n    \"qwen3_moe\",\n    \"deepseek_v3\",\n    \"openvla\",\n    \"openvla-oft\",\n}\n\n\ndef get_device_flops(unit=\"T\"):\n    def unit_convert(number, level):\n        units = [\"B\", \"K\", \"M\", \"G\", \"T\", \"P\"]\n        if number <= 0:\n            return number\n        ptr = 0\n        while ptr < len(units) and units[ptr] != level:\n            number /= 1000\n            ptr += 1\n        return number\n\n    device_name = get_torch_device().get_device_name()\n    flops = float(\"inf\")  # INF flops for unkown gpu type\n\n    if \"MI300X\" in device_name:\n        flops = 1336e12\n    elif \"H100\" in device_name or \"H800\" in device_name or \"H200\" in device_name:\n        flops = 989e12\n    elif \"A100\" in device_name or \"A800\" in device_name:\n        flops = 312e12\n    elif \"L40\" in device_name:\n        flops = 181.05e12\n    elif \"L20\" in device_name:\n        flops = 119.5e12\n    elif \"H20\" in device_name:\n        flops = 148e12\n    elif \"910B\" in device_name:\n        flops = 354e12\n    flops_unit = unit_convert(flops, unit)\n    return flops_unit\n\n\nclass FlopsCounter:\n    \"\"\"\n    Used to count mfu during training loop\n\n    Example:\n        flops_counter = FlopsCounter(config)\n        flops_achieved, flops_promised = flops_counter.estimate_flops(tokens_list, delta_time)\n\n    \"\"\"\n\n    def __init__(self, config: PretrainedConfig, forward_only: bool = False):\n        if config.model_type not in VALID_CONFIG_TYPE:\n            print(\n                f\"Only support config type of {VALID_CONFIG_TYPE}, but got {config.model_type}. 
MFU will always be zero.\"\n            )\n\n        self.estimate_func = {\n            \"qwen2\": self._estimate_qwen2_flops,\n            \"llama\": self._estimate_qwen2_flops,\n            \"qwen2_vl\": self._estimate_qwen2_flops,\n            \"qwen2_5_vl\": self._estimate_qwen2_flops,\n            \"internvl_chat\": self._estimate_internvl_flops,\n            \"qwen3\": self._estimate_qwen2_flops,\n            \"qwen3_moe\": self._estimate_qwen3_moe_flops,\n            \"deepseek_v3\": self._estimate_deepseek_v3_flops,\n            \"openvla\": self._estimate_openvla_flops,\n            \"openvla-oft\": self._estimate_openvla_flops,\n        }\n        self.config = config\n        self.forward_only = forward_only\n        self.scaling_law_coff = 2 if self.forward_only else 6\n\n    def _estimate_unknown_flops(self, tokens_sum, batch_seqlens, delta_time):\n        return 0\n\n    def _estimate_qwen2_flops(self, tokens_sum, batch_seqlens, delta_time):\n        hidden_size = self.config.hidden_size\n        vocab_size = self.config.vocab_size\n        num_hidden_layers = self.config.num_hidden_layers\n        num_key_value_heads = self.config.num_key_value_heads\n        num_attention_heads = self.config.num_attention_heads\n        intermediate_size = self.config.intermediate_size\n\n        head_dim = getattr(\n            self.config,\n            \"head_dim\",\n            self.config.hidden_size // self.config.num_attention_heads,\n        )\n        q_size = num_attention_heads * head_dim\n        k_size = num_key_value_heads * head_dim\n        v_size = num_key_value_heads * head_dim\n\n        # non-attn per layer parm\n        # Qwen2/LLama use SwiGelu, gate, having up and down linear layer in mlp\n        mlp_N = hidden_size * intermediate_size * 3\n        attn_linear_N = hidden_size * (\n            q_size + k_size + v_size + num_attention_heads * head_dim\n        )\n        emd_and_lm_head_N = vocab_size * hidden_size * 2\n        # non-attn all_layer parm\n        dense_N = (mlp_N + attn_linear_N) * num_hidden_layers + emd_and_lm_head_N\n        # non-attn all_layer & all_token fwd & bwd flops\n        dense_N_flops = self.scaling_law_coff * dense_N * tokens_sum\n\n        # attn all_layer & all_token fwd & bwd flops\n        seqlen_square_sum = 0\n        for seqlen in batch_seqlens:\n            seqlen_square_sum += seqlen * seqlen\n        attn_qkv_flops = (\n            12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers\n        )\n\n        # all_layer & all_token fwd & bwd flops\n        flops_all_token = dense_N_flops + attn_qkv_flops\n        flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12\n        return flops_achieved\n\n    def _estimate_internvl_flops(self, tokens_sum, batch_seqlens, delta_time):\n        # TODO consider vit\n        hidden_size = self.config.llm_config.hidden_size\n        vocab_size = self.config.llm_config.vocab_size\n        num_hidden_layers = self.config.llm_config.num_hidden_layers\n        num_key_value_heads = self.config.llm_config.num_key_value_heads\n        num_attention_heads = self.config.llm_config.num_attention_heads\n        intermediate_size = self.config.llm_config.intermediate_size\n\n        head_dim = getattr(\n            self.config,\n            \"head_dim\",\n            self.config.hidden_size // self.config.num_attention_heads,\n        )\n        q_size = num_attention_heads * head_dim\n        k_size = num_key_value_heads * head_dim\n        v_size = num_key_value_heads * 
head_dim\n\n        # non-attn per layer parm\n        # Qwen2/LLama use SwiGelu, gate, having up and down linear layer in mlp\n        mlp_N = hidden_size * intermediate_size * 3\n        attn_linear_N = hidden_size * (\n            q_size + k_size + v_size + num_attention_heads * head_dim\n        )\n        emd_and_lm_head_N = vocab_size * hidden_size * 2\n        # non-attn all_layer parm\n        dense_N = (mlp_N + attn_linear_N) * num_hidden_layers + emd_and_lm_head_N\n        # non-attn all_layer & all_token fwd & bwd flops\n        dense_N_flops = 6 * dense_N * tokens_sum\n\n        # attn all_layer & all_token fwd & bwd flops\n        seqlen_square_sum = 0\n        for seqlen in batch_seqlens:\n            seqlen_square_sum += seqlen * seqlen\n        attn_qkv_flops = (\n            12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers\n        )\n\n        # all_layer & all_token fwd & bwd flops\n        flops_all_token = dense_N_flops + attn_qkv_flops\n        flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12\n        return flops_achieved\n\n    def _estimate_deepseek_v3_flops(self, tokens_sum, batch_seqlens, delta_time):\n        hidden_size = self.config.hidden_size\n        vocab_size = self.config.vocab_size\n        moe_intermediate_size = self.config.moe_intermediate_size\n        num_hidden_layers = self.config.num_hidden_layers\n        first_k_dense_replace = self.config.first_k_dense_replace\n        num_query_heads = self.config.num_attention_heads\n        moe_num_expert = self.config.n_routed_experts\n\n        moe_topk = self.config.num_experts_per_tok\n        share_expert_num = self.config.n_shared_experts\n\n        # non-attn per layer parm\n        moe_gata_N = hidden_size * moe_num_expert\n        # moe has fc1_1, fc1_2 and fc2 using SwiGLU in ExpertMlp layer & shared experts\n        moe_expertmlp_N = (\n            hidden_size * moe_intermediate_size * (moe_topk + share_expert_num) * 3\n        )\n        # MLA attn\n        attn_linear_N = 0\n        q_head_dim = self.config.qk_nope_head_dim + self.config.qk_rope_head_dim\n        if self.config.q_lora_rank is None:\n            attn_linear_N += hidden_size * num_query_heads * q_head_dim\n        else:\n            attn_linear_N += hidden_size * self.config.q_lora_rank\n            attn_linear_N += num_query_heads * q_head_dim * self.config.q_lora_rank\n\n        attn_linear_N += hidden_size * (\n            self.config.kv_lora_rank + self.config.qk_rope_head_dim\n        )\n        attn_linear_N += (\n            num_query_heads\n            * (q_head_dim - self.config.qk_rope_head_dim + self.config.v_head_dim)\n            * self.config.kv_lora_rank\n        )\n        attn_linear_N += num_query_heads * self.config.v_head_dim * hidden_size\n        emd_and_lm_head_N = vocab_size * hidden_size * 2\n        # non-attn all_layer parm\n        moe_N = (\n            (moe_gata_N + moe_expertmlp_N + attn_linear_N)\n            * (num_hidden_layers - first_k_dense_replace)\n 
           + (hidden_size * self.config.intermediate_size * 3 + attn_linear_N)\n            * first_k_dense_replace\n            + emd_and_lm_head_N\n        )\n        # non-attn all_layer & all_token fwd & bwd flops\n        dense_N_flops = 6 * moe_N * tokens_sum\n\n        # attn all_layer & all_token fwd & bwd flops\n        seqlen_square_sum = 0\n        for seqlen in batch_seqlens:\n            seqlen_square_sum += seqlen * seqlen * num_hidden_layers\n\n        attn_qkv_flops = 12 * seqlen_square_sum * q_head_dim * num_query_heads\n        # all_layer & all_token fwd & bwk flops\n        flops_all_token = dense_N_flops + attn_qkv_flops\n        flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12\n\n        return flops_achieved\n\n    def _estimate_qwen3_moe_flops(self, tokens_sum, batch_seqlens, delta_time):\n        hidden_size = self.config.hidden_size\n        vocab_size = self.config.vocab_size\n        num_hidden_layers = self.config.num_hidden_layers\n        num_key_value_heads = self.config.num_key_value_heads\n        num_attention_heads = self.config.num_attention_heads\n        moe_intermediate_size = self.config.moe_intermediate_size\n        moe_topk = self.config.num_experts_per_tok\n        num_experts = self.config.num_experts\n\n        head_dim = getattr(\n            self.config,\n            \"head_dim\",\n            self.config.hidden_size // self.config.num_attention_heads,\n        )\n        q_size = num_attention_heads * head_dim\n        k_size = num_key_value_heads * head_dim\n        v_size = num_key_value_heads * head_dim\n\n        # non-attn per layer parm\n        # gate + moe export\n        moe_mlp_N = (\n            hidden_size * moe_topk * moe_intermediate_size * 3\n            + hidden_size * num_experts\n        )\n        attn_linear_N = hidden_size * (\n            q_size + k_size + v_size + num_attention_heads * head_dim\n        )\n        emd_and_lm_head_N = vocab_size * hidden_size * 2\n        # non-attn all_layer parm\n        dense_N = (moe_mlp_N + attn_linear_N) * num_hidden_layers + emd_and_lm_head_N\n        # non-attn all_layer & all_token fwd & bwd flops\n        dense_N_flops = 6 * dense_N * tokens_sum\n\n        # attn all_layer & all_token fwd & bwd flops\n        seqlen_square_sum = 0\n        for seqlen in batch_seqlens:\n            seqlen_square_sum += seqlen * seqlen\n        attn_qkv_flops = (\n            12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers\n        )\n\n        # all_layer & all_token fwd & bwd flops\n        flops_all_token = dense_N_flops + attn_qkv_flops\n        flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12\n        return flops_achieved\n\n    def _estimate_openvla_flops(self, tokens_sum, batch_seqlens, delta_time):\n        \"\"\"\n        Estimate FLOPs for OpenVLA/OpenVLA-OFT models.\n        \n        OpenVLA architecture:\n        - Vision encoder (frozen, not counted)\n        - Projector MLP (vision_dim -> llm_dim)\n        - Language model (typically Llama/Vicuna-based)\n        \n        The main computation comes from the language model backbone.\n        \"\"\"\n        # Access LLM config from text_config (nested config)\n        if not hasattr(self.config, 'text_config'):\n            # Fallback to zero if text_config is not available\n            print(\"Warning: OpenVLA config missing text_config, cannot estimate FLOPs\")\n            return 0\n        \n        llm_config = self.config.text_config\n        \n        # Extract LLM 
parameters\n        hidden_size = llm_config.hidden_size\n        vocab_size = llm_config.vocab_size\n        num_hidden_layers = llm_config.num_hidden_layers\n        num_key_value_heads = llm_config.num_key_value_heads\n        num_attention_heads = llm_config.num_attention_heads\n        intermediate_size = llm_config.intermediate_size\n\n        head_dim = getattr(\n            llm_config,\n            \"head_dim\",\n            llm_config.hidden_size // llm_config.num_attention_heads,\n        )\n        q_size = num_attention_heads * head_dim\n        k_size = num_key_value_heads * head_dim\n        v_size = num_key_value_heads * head_dim\n\n        # Language model FLOPs (similar to LLaMA)\n        # Llama uses SwiGLU, having up and down linear layers in MLP\n        mlp_N = hidden_size * intermediate_size * 3\n        attn_linear_N = hidden_size * (\n            q_size + k_size + v_size + num_attention_heads * head_dim\n        )\n        \n        # For OpenVLA, we use vocab_size from LLM config\n        # Note: OpenVLA uses action tokens, but they're part of the vocab\n        emd_and_lm_head_N = vocab_size * hidden_size * 2\n        \n        # Total dense parameters across all layers\n        dense_N = (mlp_N + attn_linear_N) * num_hidden_layers + emd_and_lm_head_N\n        \n        # Dense layer FLOPs (forward + backward)\n        dense_N_flops = self.scaling_law_coff * dense_N * tokens_sum\n\n        # Attention FLOPs\n        seqlen_square_sum = 0\n        for seqlen in batch_seqlens:\n            seqlen_square_sum += seqlen * seqlen\n        attn_qkv_flops = (\n            12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers\n        )\n\n        # Total FLOPs\n        flops_all_token = dense_N_flops + attn_qkv_flops\n        flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12\n        return flops_achieved\n\n    def estimate_flops(self, batch_seqlens, delta_time):\n        \"\"\"\n        Estimate the FLOPS based on the number of valid tokens in the current batch and the time taken.\n\n        Args:\n            batch_seqlens (List[int] or int): A list where each element represents the number of valid tokens in the current batch,\n                or a single integer representing total tokens. 
Can also handle nested lists.\n            delta_time (float): The time taken to process the batch, in seconds.\n\n        Returns:\n            estimated_flops (float): The estimated FLOPS based on the input tokens and time.\n            promised_flops (float): The expected FLOPS of the current device.\n        \"\"\"\n        # Normalize batch_seqlens to a flat list of integers\n        def flatten_to_ints(data):\n            \"\"\"Recursively flatten nested lists/tuples to a flat list of integers.\"\"\"\n            if isinstance(data, (int, float)):\n                return [int(data)]\n            elif isinstance(data, (list, tuple)):\n                result = []\n                for item in data:\n                    result.extend(flatten_to_ints(item))\n                return result\n            else:\n                # If it's some other type (e.g., tensor), try to convert\n                try:\n                    return flatten_to_ints(list(data))\n                except:\n                    # Fallback: treat as single item\n                    return [int(data)]\n        \n        batch_seqlens_flat = flatten_to_ints(batch_seqlens)\n        tokens_sum = sum(batch_seqlens_flat)\n        \n        # Use the flattened list for further processing\n        batch_seqlens = batch_seqlens_flat\n        func = self.estimate_func.get(\n            self.config.model_type, self._estimate_unknown_flops\n        )\n        estimated_flops = func(tokens_sum, batch_seqlens, delta_time)\n        promised_flops = get_device_flops()\n        return estimated_flops, promised_flops\n"
  },
  {
    "path": "siirl/utils/model_utils/fsdp_utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport functools\nimport itertools\nimport json\nimport math\nimport os\nfrom collections import OrderedDict\nfrom contextlib import contextmanager, nullcontext\nfrom typing import Dict\n\nimport torch\nimport torch.distributed as dist\nimport torch.nn as nn\nfrom packaging import version\nfrom torch.distributed import DeviceMesh\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP\nfrom torch.distributed.fsdp._runtime_utils import _lazy_init\nfrom torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy\nfrom transformers.trainer_pt_utils import get_module_class_from_name\nfrom loguru import logger\n\nfrom siirl.utils.extras.device import get_device_id, get_device_name, get_torch_device\n\nif version.parse(torch.__version__) >= version.parse(\"2.6\"):\n    from torch.distributed.fsdp import CPUOffloadPolicy, FSDPModule, MixedPrecisionPolicy, fully_shard\nelif version.parse(torch.__version__) >= version.parse(\"2.4\"):\n    from torch.distributed._composable.fsdp import CPUOffloadPolicy, FSDPModule, MixedPrecisionPolicy, fully_shard\nelse:\n    fully_shard, MixedPrecisionPolicy, FSDPModule, CPUOffloadPolicy = None, None, None, None\n\n\ndef init_fn(x: torch.nn.Module):\n    if torch.distributed.get_rank() != 0:\n        x = x.to_empty(device=get_device_id(), recurse=False)\n        get_torch_device().empty_cache()\n    return x\n\n\ndef get_init_weight_context_manager(use_meta_tensor=True, mesh: DeviceMesh = None):\n    from accelerate import init_empty_weights\n\n    cpu_init_weights = lambda: torch.device(\"cpu\")\n    if use_meta_tensor:\n        if mesh is None:\n            init_context = init_empty_weights if torch.distributed.get_rank() != 0 else cpu_init_weights\n        else:\n            init_context = init_empty_weights if mesh.get_coordinate()[-1] != 0 else cpu_init_weights\n    else:\n        init_context = cpu_init_weights\n    return init_context\n\n\n# Copyright 2020-present the HuggingFace Inc. team.\n# Adapted from https://github.com/huggingface/transformers/src/transformers/trainer.py\ndef get_fsdp_wrap_policy_vla(module, config=None, is_lora=False):\n    \"\"\"\n    Get FSDP wrap policy specifically for VLA (Vision-Language-Action) models.\n    \n    VLA models have a three-component architecture:\n    1. Vision Backbone (VisionTransformer + Block)\n    2. Projector (PrismaticProjector) \n    3. 
Language Model (Transformer layers)\n    \n    Args:\n        module: The VLA module to get wrap policy for\n        config: Configuration for wrap policy (currently unused, for API compatibility)\n        is_lora: Whether to enable lambda policy for LoRA modules\n        \n    Returns:\n        FSDP auto wrap policy combining vision, projector, and language model policies\n    \"\"\"\n    from torch.distributed.fsdp.wrap import (\n        _module_wrap_policy, \n        _or_policy, \n        transformer_auto_wrap_policy, \n        lambda_auto_wrap_policy\n    )\n    \n    # Import VLA-specific classes\n    try:\n        from timm.models.vision_transformer import Block, VisionTransformer\n    except ImportError:\n        raise ImportError(\"timm is required for VLA models. Install with: pip install timm\")\n    \n    # 1. Vision Backbone Policy\n    vit_wrap_policy = functools.partial(\n        _module_wrap_policy, \n        module_classes={VisionTransformer}\n    )\n    transformer_block_policy = functools.partial(\n        transformer_auto_wrap_policy, \n        transformer_layer_cls={Block}\n    )\n    vision_fsdp_wrapping_policy = functools.partial(\n        _or_policy, \n        policies=[vit_wrap_policy, transformer_block_policy]\n    )\n    \n    # 2. Language Model Policy\n    # VLA models have nested structure: module.language_model contains the LLM\n    default_transformer_cls_names_to_wrap = getattr(\n        module.language_model, \"_no_split_modules\", None\n    )\n    \n    if default_transformer_cls_names_to_wrap is None:\n        raise ValueError(\n            \"Cannot find _no_split_modules in module.language_model. \"\n            \"Ensure the module is a valid VLA model with language_model attribute.\"\n        )\n    \n    transformer_cls_to_wrap = set()\n    for layer_class in default_transformer_cls_names_to_wrap:\n        logger.info(f\"VLA LLM layer class: {layer_class}\")\n        transformer_cls = get_module_class_from_name(module, layer_class)\n        if transformer_cls is None:\n            raise ValueError(\n                f\"Could not find transformer layer class '{layer_class}' in the model. \"\n                f\"Available modules: {[n for n, _ in module.named_modules()][:10]}...\"\n            )\n        transformer_cls_to_wrap.add(transformer_cls)\n    \n    llm_wrap_policy = functools.partial(\n        transformer_auto_wrap_policy,\n        transformer_layer_cls=transformer_cls_to_wrap,\n    )\n    logger.info(f\"VLA LLM wrap policy configured with {len(transformer_cls_to_wrap)} layer types\")\n    \n    # 3. Projector Policy\n    # Import PrismaticProjector - try both openvla and openvla-oft\n    PrismaticProjector = None\n    try:\n        from siirl.models.embodied.openvla_oft.modeling_prismatic import PrismaticProjector\n    except ImportError:\n        try:\n            from siirl.models.embodied.openvla.modeling_prismatic import PrismaticProjector\n        except ImportError:\n            raise ImportError(\n                \"Cannot import PrismaticProjector. Ensure VLA model files are available.\"\n            )\n    \n    prismatic_fsdp_wrapping_policy = functools.partial(\n        _module_wrap_policy,\n        module_classes={PrismaticProjector},\n    )\n    \n    # 4. Build combined policy list\n    vla_policies = [\n        vision_fsdp_wrapping_policy,\n        llm_wrap_policy,\n        prismatic_fsdp_wrapping_policy,\n    ]\n    \n    # 5. 
Add LoRA policy if needed\n    if is_lora:\n        def lambda_policy_fn(module):\n            return bool(\n                len(list(module.named_children())) == 0\n                and getattr(module, \"weight\", None) is not None\n                and module.weight.requires_grad\n            )\n        \n        lambda_policy = functools.partial(\n            lambda_auto_wrap_policy, \n            lambda_fn=lambda_policy_fn\n        )\n        vla_policies.append(lambda_policy)\n        logger.info(\"Added LoRA lambda policy for VLA model\")\n    \n    return functools.partial(_or_policy, policies=vla_policies)\n\n\ndef get_fsdp_wrap_policy(module, config=None, is_lora=False):\n    \"\"\"Get FSDP wrap policy for the module.\n    \n    Automatically detects model type and routes to appropriate policy:\n    - VLA models (OpenVLA, OpenVLA-OFT) → get_fsdp_wrap_policy_vla\n    - Standard models (LLMs, VLMs) → standard transformer policy\n\n    Args:\n        module: The module to get wrap policy for\n        config: Configuration for wrap policy\n        is_lora: Whether to enable lambda policy for LoRA modules\n    \"\"\"\n    # Auto-detect VLA models\n    is_embodied_model = (\n        hasattr(module, 'vision_backbone') \n        and hasattr(module, 'projector')\n        and hasattr(module, 'language_model')\n    )\n    \n    if is_embodied_model:\n        logger.info(\"🎯 Detected VLA model architecture, using specialized VLA wrap policy\")\n        return get_fsdp_wrap_policy_vla(module, config, is_lora)\n    \n    # Standard policy for non-VLA models\n    logger.info(\"📦 Using standard FSDP wrap policy\")\n    \n    if config is None:\n        config = {}\n\n    # NOTE: This is a temporary workaround to be compatible with the OmegaConf & dataclass. 
We will remove this once we have make all config in siirl from OmegaConf to data class.\n    def _get_attr(attr_name, default_value=None):\n        if hasattr(config, \"get\"):\n            return config.get(attr_name, default_value)\n        else:\n            return config.__getattribute__(attr_name)\n\n    if _get_attr(\"disable\", False):\n        return None\n\n    default_transformer_cls_names_to_wrap = getattr(module, \"_no_split_modules\", None)\n    fsdp_transformer_layer_cls_to_wrap = _get_attr(\"transformer_layer_cls_to_wrap\", default_transformer_cls_names_to_wrap)\n    min_num_params = _get_attr(\"min_num_params\", 0)\n    auto_wrap_policy = None\n\n    policies = []\n\n    from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy\n\n    # Add lambda policy for LoRA modules if is_lora is True\n    if is_lora:\n\n        def lambda_policy_fn(module):\n            return bool(len(list(module.named_children())) == 0 and getattr(module, \"weight\", None) is not None and module.weight.requires_grad)\n\n        lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)\n        policies.append(lambda_policy)\n\n    if min_num_params > 0:\n        size_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=min_num_params)\n        policies.append(size_policy)\n    elif fsdp_transformer_layer_cls_to_wrap is not None:\n        transformer_cls_to_wrap = set()\n        for layer_class in fsdp_transformer_layer_cls_to_wrap:\n            transformer_cls = get_module_class_from_name(module, layer_class)\n            if transformer_cls is None:\n                logger.warning(f\"Could not find layer {layer_class}\")\n                # raise Exception(\"Could not find the transformer layer class to wrap in the model.\")\n            else:\n                transformer_cls_to_wrap.add(transformer_cls)\n\n        transformer_policy = functools.partial(\n            transformer_auto_wrap_policy,\n            transformer_layer_cls=transformer_cls_to_wrap,\n        )\n        policies.append(transformer_policy)\n\n    if len(policies) > 0:\n        auto_wrap_policy = functools.partial(_or_policy, policies=policies)\n\n    return auto_wrap_policy\n\n\n@torch.no_grad()\ndef offload_fsdp_model_to_cpu(model: FSDP, empty_cache: bool = True):\n    if fsdp_version(model) == 2:\n        offload_fsdp2_model_to_cpu(model, empty_cache)\n        return\n\n    assert isinstance(model, FSDP)\n    # lazy init FSDP model\n    _lazy_init(model, model)\n    assert model._is_root, \"Only support root model offloading to CPU\"\n    for handle in model._all_handles:\n        if handle._offload_params:\n            continue\n        flat_param = handle.flat_param\n        assert flat_param.data.data_ptr() == flat_param._local_shard.data_ptr() and id(flat_param.data) != id(flat_param._local_shard) and flat_param.data.size() == flat_param._local_shard.size()\n        handle.flat_param_to(torch.device(\"cpu\"), non_blocking=True)\n        # the following still keeps id(._local_shard) != id(.data)\n        flat_param._local_shard = flat_param.data\n        assert id(flat_param._local_shard) != id(flat_param.data)\n    if empty_cache:\n        get_torch_device().empty_cache()\n\n\n@torch.no_grad()\ndef offload_fsdp2_model_to_cpu(model, empty_cache: bool = True):\n    for param in model.parameters():\n        param.data = param.data.to(torch.device(\"cpu\"), non_blocking=True)\n    if empty_cache:\n        get_torch_device().empty_cache()\n\n\n@torch.no_grad()\ndef 
load_fsdp_model_to_gpu(model: FSDP):\n    if fsdp_version(model) == 2:\n        load_fsdp2_model_to_gpu(model)\n        return\n\n    assert isinstance(model, FSDP)\n    # lazy init FSDP model\n    _lazy_init(model, model)\n    assert model._is_root, \"Only support root model loading to GPU\"\n    device_id = get_device_id()\n    for handle in model._all_handles:\n        if handle._offload_params:\n            continue\n        flat_param = handle.flat_param\n        handle.flat_param_to(torch.device(f\"{get_device_name()}:{device_id}\"), non_blocking=True)\n        # the following still keeps id(._local_shard) != id(.data)\n        flat_param._local_shard = flat_param.data\n\n\n@torch.no_grad()\ndef load_fsdp2_model_to_gpu(model):\n    device = torch.cuda.current_device()\n    for param in model.parameters():\n        param.data = param.data.to(device, non_blocking=True)\n\n\n@torch.no_grad()\ndef offload_fsdp_optimizer(optimizer):\n    if not optimizer.state:\n        return\n    for param_group in optimizer.param_groups:\n        for param in param_group[\"params\"]:\n            state = optimizer.state[param]\n            for key, value in state.items():\n                if isinstance(value, torch.Tensor):\n                    state[key] = value.to(\"cpu\", non_blocking=True)\n\n\n@torch.no_grad()\ndef load_fsdp_optimizer(optimizer, device_id):\n    if not optimizer.state:\n        return\n    for param_group in optimizer.param_groups:\n        for param in param_group[\"params\"]:\n            state = optimizer.state[param]\n            for key, value in state.items():\n                if isinstance(value, torch.Tensor):\n                    state[key] = value.to(device_id, non_blocking=True)\n\n\n@contextmanager\ndef meta_device_init():\n    \"\"\"\n    Create model parameters with meta device.\n\n    Note buffers in model will still be initialized in default device (e.g., CPU),\n    since the buffers can be non-persistent and filled with expected values that can\n    NOT be captured in meta device.\n    \"\"\"\n    device = torch.device(\"meta\")\n    old_register_parameter = nn.Module.register_parameter\n    registered = set()\n\n    def register_empty_parameter(module, name, param):\n        old_register_parameter(module, name, param)\n        # we will skip register shared parameters as it\n        # is already registered previously\n        if param is not None and param not in registered:\n            param_cls = type(module._parameters[name])\n            kwargs = module._parameters[name].__dict__\n            kwargs[\"requires_grad\"] = param.requires_grad\n            module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)\n            registered.add(module._parameters[name])\n\n    try:\n        nn.Module.register_parameter = register_empty_parameter\n        yield\n    finally:\n        registered.clear()\n        nn.Module.register_parameter = old_register_parameter\n\n\ndef parallel_load_safetensors(filepath):\n    \"\"\"\n    Parallel load safetensors from huggingface checkpoint\n\n    Huggingface checkpoint contains:\n\n    - config.json: a json file for model configuration\n    - model.safetensor.index.json: a json file for safetensors (parameters & buffers) index\n    - model-000x-of-ooxx.safetensors: a binary file for safetensors (parameters & buffers) chunks\n\n    Or (when model is small),\n\n    - model.safetensors: a binary file for all parameters and buffers\n\n    Each rank will own a part of model chunks and load them directly 
into GPU memory.\n    \"\"\"\n    from safetensors.torch import load_file\n\n    safetensors2param = {}\n\n    index_file = os.path.join(filepath, \"model.safetensors.index.json\")\n    if os.path.exists(index_file):\n        index = json.load(open(index_file, \"rb\"))\n        for param_name, filename in index[\"weight_map\"].items():\n            safetensors2param.setdefault(filename, []).append(param_name)\n    else:\n        # in this case, the model is small and we can load it all at once\n        param_file = os.path.join(filepath, \"model.safetensors\")\n        assert os.path.exists(param_file), f\"Cannot find {param_file}\"\n        states = load_file(param_file)\n        for param_name in states:\n            safetensors2param.setdefault(\"model.safetensors\", []).append(param_name)\n        del states\n\n    total_files = len(safetensors2param)\n    ckpt_chunks = sorted(safetensors2param.keys())\n    world_size = dist.get_world_size()\n    size = int(math.ceil(total_files / world_size))\n    ckpt_chunks = [ckpt_chunks[rank * size : rank * size + size] for rank in range(world_size)]\n\n    shard_states = {}\n    device = get_device_id()\n    for rank, files in enumerate(ckpt_chunks):\n        if rank == dist.get_rank():\n            for file in files:\n                file = os.path.join(filepath, file)\n                states = load_file(file, device=device)\n                # print(f\"rank {rank} loading {file}...\")\n                shard_states.update(states)\n        else:\n            for file in files:\n                for param_name in safetensors2param[file]:\n                    shard_states[param_name] = rank\n    return shard_states\n\n\ndef parallel_init_module_fn(module: torch.nn.Module, shard_states: Dict[str, torch.nn.Parameter]):\n    \"\"\"\n    Generate a function to initialize sub-modules in the `module` with `shard_states`\n    from huggingface checkpoint.\n\n    Args:\n        module (torch.nn.Module): the global module to be initialized\n        shard_states (Dict[str, torch.nn.Parameter]): the shard states from huggingface checkpoint\n\n    Returns:\n        init_fn (Callable): a function to initialize sub-modules in the `module` with `shard_states`\n    \"\"\"\n\n    state2fqn = {}\n    for name, state in itertools.chain(module.named_parameters(remove_duplicate=False), module.named_buffers(remove_duplicate=False)):\n        state2fqn.setdefault(state, []).append(name)\n    # remove standalone parameters and buffers\n    shared = {s for s, names in state2fqn.items() if len(names) > 1}\n    materialized_states = {}\n\n    @torch.no_grad()\n    def create_and_sync_state(param_name, state, is_param):\n        assert param_name in shard_states, f\"{param_name} not loaded\"\n        device = get_device_id()\n        if is_param:\n            param = torch.nn.Parameter(torch.empty_like(state.data, device=device), requires_grad=state.requires_grad)\n        else:  # buffer\n            param = torch.empty_like(state.data, device=device)\n        loaded = shard_states[param_name]\n        if isinstance(loaded, (torch.nn.Parameter, torch.Tensor)):\n            # NOTE: loaded.dtype can be different with param.dtype\n            param.data.copy_(loaded.data)\n            dist.broadcast(param.data, src=dist.get_rank())\n        else:\n            assert isinstance(loaded, int)  # the rank that holds the state\n            dist.broadcast(param.data, src=loaded)\n        shard_states.pop(param_name)\n        del loaded\n        return param\n\n    def init_fn(sub_mod: 
torch.nn.Module, recurse: bool = True):\n        param_and_buffers = tuple(sub_mod.named_parameters(recurse=False)) + tuple(sub_mod.named_buffers(recurse=False))\n        # param_and_buffers = sorted(sub_mod.named_parameters(recurse=False), key=lambda x: x[0])\n        for name, state in param_and_buffers:\n            if not state.is_meta:\n                continue\n            is_param = name in sub_mod._parameters\n            fqn = state2fqn[state].pop(0)\n            # non-persistent buffers will not be saved in state dict, we can safely skip it\n            if (not is_param) and fqn not in shard_states:\n                if state.is_meta:\n                    raise RuntimeError(f\"find a non-persistent buffer ({fqn}) initiated with device meta. Such buffer is not saved in checkpoint and user should guarantee to init in CPU / GPU device.\")\n                continue\n            # for shared parameter, we get it from the first time it is created\n            if state in shared:\n                if state not in materialized_states:\n                    materialized_states[state] = create_and_sync_state(fqn, state, is_param)\n                else:\n                    if fqn in shard_states:\n                        shard_states.pop(fqn)\n                materialize_state = materialized_states[state]\n            # for not shared parameter, we create it directly\n            else:\n                materialize_state = create_and_sync_state(fqn, state, is_param)\n            if is_param:\n                sub_mod._parameters[name] = materialize_state\n            else:\n                sub_mod._buffers[name] = materialize_state\n        if recurse:\n            for module in sub_mod.children():\n                init_fn(module, recurse=True)\n\n        # for debug\n        # if len(shard_states) == 0: print(\"clear\")\n        return sub_mod\n\n    return init_fn\n\n\ndef fsdp_version(model):\n    if isinstance(model, FSDP):\n        return 1\n    elif isinstance(model, FSDPModule):\n        return 2\n    else:\n        return 0\n\n\ndef get_fsdp_state_ctx(model, state_type, state_cfg, optim_cfg):\n    if fsdp_version(model) == 1:\n        return FSDP.state_dict_type(model, state_type, state_cfg, optim_cfg)\n    else:\n        return nullcontext()\n\n\ndef fsdp2_load_full_state_dict(model: torch.nn.Module, full_state: dict, device_mesh=None, cpu_offload=None):\n    \"\"\"\n    Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the\n    parameters from rank 0 to all other ranks. 
This function modifies the model in-place.\n\n    Args:\n        model (`torch.nn.Module`): The model to load the state dict into\n        full_state (`dict`): The full state dict to load, can only be on rank 0\n    \"\"\"\n    from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict\n\n    # To broadcast, it needs to be instantiated in the GPU.\n    if dist.get_rank() == 0:\n        model = model.to(device=torch.cuda.current_device(), non_blocking=True)\n    else:\n        model = model.to_empty(device=torch.cuda.current_device())\n\n    cpu_offload = cpu_offload is not None\n    options = StateDictOptions(full_state_dict=True, cpu_offload=cpu_offload, broadcast_from_rank0=True)\n    set_model_state_dict(model, full_state, options=options)\n\n    # rotary_emb is not in state_dict, so we need to broadcast it manually\n    for name, buf in model.named_buffers():\n        dist.broadcast(buf, src=0)\n\n    if cpu_offload:\n        model.to(\"cpu\", non_blocking=True)\n        for buf in model.buffers():\n            buf.data = buf.data.to(torch.cuda.current_device())\n\n\ndef apply_fsdp2(model, fsdp_kwargs, config):\n    \"\"\"model: AutoModelForCausalLM\"\"\"\n    assert CPUOffloadPolicy is not None, \"PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)\"\n\n    default_transformer_cls_names_to_wrap = getattr(model, \"_no_split_modules\", None)\n    fsdp_transformer_layer_cls_to_wrap = config.get(\"wrap_policy\", {}).get(\"transformer_layer_cls_to_wrap\", default_transformer_cls_names_to_wrap)\n\n    if isinstance(fsdp_transformer_layer_cls_to_wrap, str):\n        fsdp_transformer_layer_cls_to_wrap = [fsdp_transformer_layer_cls_to_wrap]\n\n    assert len(fsdp_transformer_layer_cls_to_wrap) > 0 and fsdp_transformer_layer_cls_to_wrap[0] is not None\n\n    modules = []\n    for name, module in model.named_modules():\n        if module.__class__.__name__ in fsdp_transformer_layer_cls_to_wrap or (isinstance(module, nn.Embedding) and not model.config.tie_word_embeddings):\n            modules.append(module)\n\n    for idx, module in enumerate(modules):\n        fully_shard(module, **fsdp_kwargs)\n    fully_shard(model, **fsdp_kwargs)  # fsdp2 will not reshard_after_forward for root module\n\n\ndef fsdp2_clip_grad_norm_(parameters, max_norm, norm_type=2.0, error_if_nonfinite=False, foreach=None):\n    \"\"\"torch.nn.utils.clip_grad_norm_ cann't run on cpu parameter DTensor\"\"\"\n    from torch.nn.utils.clip_grad import _clip_grads_with_norm_, _get_total_norm\n\n    if isinstance(parameters, torch.Tensor):\n        parameters = [parameters]\n    else:\n        # prevent generators from being exhausted\n        parameters = list(parameters)\n    grads = [p.grad for p in parameters if p.grad is not None]\n    total_norm = _get_total_norm(grads, norm_type, error_if_nonfinite, foreach)\n    total_norm = total_norm.to(torch.cuda.current_device(), non_blocking=True)\n    _clip_grads_with_norm_(parameters, max_norm, total_norm, foreach)\n    return total_norm\n\n\ndef layered_summon_lora_params(fsdp_module) -> OrderedDict:\n    from peft.utils.save_and_load import get_peft_model_state_dict\n\n    def __prefix_submodules(module, prefix):\n        for name, submodule in module.named_modules():\n            if name.startswith(prefix) and \".\" not in name[len(prefix) :]:\n                yield name, submodule\n\n    lora_params = OrderedDict()\n    prefix_list = [\n        # fsdp\n        \"_fsdp_wrapped_module.base_model.model.\",\n        
\"_fsdp_wrapped_module.base_model.model.model.\",\n        \"_fsdp_wrapped_module.base_model.model.model.layers.\",\n        # fsdp2\n        \"base_model.model.\",\n        \"base_model.model.model.\",\n        \"base_model.model.model.layers.\",\n    ]\n    peft_model = getattr(fsdp_module, \"_fsdp_wrapped_module\", fsdp_module)\n    for prefix in prefix_list:\n        for name, submodule in __prefix_submodules(fsdp_module, prefix):\n            prefix = name.replace(\"_fsdp_wrapped_module.base_model.model.\", \"base_model.model.\")\n            if name.endswith(\".model\") or name.endswith(\".layers\"):\n                continue\n            if fsdp_version(submodule) > 0:\n                with FSDP.summon_full_params(submodule, writeback=False):\n                    sub_lora_params = get_peft_model_state_dict(peft_model, state_dict=submodule.state_dict())\n                    sub_lora_params = {f\"{prefix}.{name}\": param.full_tensor().detach().cpu() if hasattr(param, \"full_tensor\") else param.detach().cpu() for name, param in sub_lora_params.items()}\n                    lora_params.update(sub_lora_params)\n                    submodule._is_root = False\n                get_torch_device().empty_cache()   \n    return lora_params\n"
  },
  {
    "path": "siirl/utils/model_utils/model.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nUtilities to create common models from huggingface\n\"\"\"\n\nimport os\nimport re\nimport warnings\nfrom dataclasses import dataclass\nfrom typing import Dict, Optional, Type\n\nimport numpy as np\nimport torch\nfrom loguru import logger\nfrom torch import nn\nfrom transformers import (\n    AutoConfig,\n    AutoModelForCausalLM,\n    GenerationConfig,\n    MistralForSequenceClassification,\n    PretrainedConfig,\n    PreTrainedModel,\n)\nfrom transformers.modeling_outputs import CausalLMOutputWithPast\nfrom siirl.models.registry import ModelRegistry\n\n\nclass LambdaLayer(nn.Module):\n    def __init__(self, fn):\n        super().__init__()\n        self.fn = fn\n\n    def forward(self, *args, **kwargs):\n        return self.fn(*args, **kwargs)\n\n\ndef squeeze(x):\n    return torch.squeeze(x, dim=-1)\n\n\ndef update_model_config(module_config, override_config_kwargs):\n    \"\"\"Update the module config with the override_config_kwargs.\n    Args:\n        module_config: The module config from Huggingface Transformers.\n        override_config_kwargs: The kwargs to override the module config.\n    \"\"\"\n    for key, val in override_config_kwargs.items():\n        if isinstance(val, dict):\n            update_model_config(getattr(module_config, key), val)\n        else:\n            setattr(module_config, key, val)\n\n\ndef get_huggingface_actor_config(model_name: str, override_config_kwargs=None, trust_remote_code=False) -> Dict:\n    if override_config_kwargs is None:\n        override_config_kwargs = {}\n    assert isinstance(override_config_kwargs, Dict), (\n        f\"override_config_kwargs must be a dict, got {type(override_config_kwargs)}\"\n    )\n    module_config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code)\n    update_model_config(module_config, override_config_kwargs)\n\n    return module_config\n\n\ndef get_generation_config(\n    model: str,\n    trust_remote_code: bool = False,\n) -> Optional[GenerationConfig]:\n    try:\n        return GenerationConfig.from_pretrained(model)\n    except OSError:  # Not found\n        try:\n            config = get_huggingface_actor_config(\n                model,\n                trust_remote_code=trust_remote_code,\n            )\n            return GenerationConfig.from_model_config(config)\n        except OSError:  # Not found\n            return None\n\n\ndef create_huggingface_actor(model_name: str, override_config_kwargs=None, automodel_kwargs=None) -> nn.Module:\n    \"\"\"\n\n    Args:\n        model_name:\n        override_config_kwargs:\n\n    Returns:\n\n    \"\"\"\n    if override_config_kwargs is None:\n        override_config_kwargs = {}\n    if automodel_kwargs is None:\n        automodel_kwargs = {}\n    assert isinstance(override_config_kwargs, Dict), (\n        f\"override_config_kwargs must be a dict, got {type(override_config_kwargs)}\"\n    )\n    
module_config = get_huggingface_actor_config(\n        model_name, override_config_kwargs, trust_remote_code=automodel_kwargs.get(\"trust_remote_code\", False)\n    )\n    module: nn.Module = AutoModelForCausalLM.from_config(module_config, **automodel_kwargs)\n    return module\n\n\ndef create_huggingface_critic(model_name: str, override_config_kwargs=None, automodel_kwargs=None) -> nn.Module:\n    \"\"\"\n\n    Args:\n        model_name:\n        override_config_kwargs:\n\n    Returns:\n\n    \"\"\"\n    critic_module: nn.Module = create_huggingface_actor(\n        model_name, override_config_kwargs=override_config_kwargs, automodel_kwargs=automodel_kwargs\n    )\n    if automodel_kwargs is None:\n        automodel_kwargs = {}\n    torch_dtype = automodel_kwargs.get(\"torch_dtype\", torch.float32)\n    critic_module.lm_head = nn.Sequential(\n        nn.Linear(critic_module.config.hidden_size, 1, dtype=torch_dtype), LambdaLayer(fn=squeeze)\n    )\n    return critic_module\n\n\ndef get_model_size(model: nn.Module, scale=\"auto\"):\n    n_params = sum(p.numel() for p in model.parameters())\n\n    if scale == \"auto\":\n        if n_params > 1e9:\n            scale = \"B\"\n        elif n_params > 1e6:\n            scale = \"M\"\n        elif n_params > 1e3:\n            scale = \"K\"\n        else:\n            scale = \"\"\n\n    if scale == \"B\":\n        n_params = n_params / 1e9\n    elif scale == \"M\":\n        n_params = n_params / 1e6\n    elif scale == \"K\":\n        n_params = n_params / 1e3\n    elif scale == \"\":\n        pass\n    else:\n        raise NotImplementedError(f\"Unknown scale {scale}\")\n\n    return n_params, scale\n\n\ndef print_model_size(model: nn.Module, name: str = None):\n    n_params, scale = get_model_size(model, scale=\"auto\")\n    if name is None:\n        name = model.__class__.__name__\n    logger.info(f\"{name} contains {n_params:.2f}{scale} parameters\")\n\n\ndef create_random_mask(\n    input_ids: torch.Tensor,\n    max_ratio_of_valid_token: float,\n    max_ratio_of_left_padding: float,\n    min_ratio_of_valid_token: float = 0,\n):\n    \"\"\"Create a random mask given input_ids. 
Support left padding and right padding.\n    Process:\n    - Sample valid token length\n    - Sample left_padding length\n    - Generate padding\n\n    Args:\n        input_ids:\n            shape (batch_size, seq_len)\n\n    Returns:\n\n    \"\"\"\n    assert max_ratio_of_valid_token > 0 and max_ratio_of_valid_token <= 1.0\n    assert max_ratio_of_left_padding >= 0 and max_ratio_of_left_padding < 1.0\n    assert min_ratio_of_valid_token <= max_ratio_of_valid_token\n\n    batch_size, sequence_length = input_ids.shape\n    max_num_valid_tokens = int(sequence_length * max_ratio_of_valid_token)\n    min_num_valid_tokens = max(1, int(sequence_length * min_ratio_of_valid_token))\n    max_left_padding = int(sequence_length * max_ratio_of_left_padding)\n    assert max_num_valid_tokens + max_left_padding <= sequence_length\n    assert max_num_valid_tokens > 0 and max_ratio_of_valid_token <= sequence_length\n    masks = torch.ones_like(input_ids, dtype=torch.int64)\n    # TODO: we can make this faster\n    for i in range(batch_size):\n        num_left_padding = np.random.randint(low=0, high=max_left_padding + 1, dtype=np.int64)\n        num_valid = np.random.randint(low=min_num_valid_tokens, high=max_num_valid_tokens + 1, dtype=np.int64)\n\n        for index in range(num_left_padding):\n            masks[i, index] = 0\n\n        for index in range(num_left_padding + num_valid, sequence_length):\n            masks[i, index] = 0\n    return masks\n\n\ndef compute_position_id_with_mask(mask):\n    return torch.clip(torch.cumsum(mask, dim=-1) - 1, min=0, max=None)\n\n\ndef normalize_model_name(name, pp_rank, vpp_rank, transformer_config, layer_name=\"layers\"):\n    \"\"\"\n    Transform the model name in each model_chunk in each pp stage into the name in inference engine\n    \"\"\"\n    from siirl.utils.megatron.megatron_utils import get_transformer_layer_offset\n\n    layer_offset = get_transformer_layer_offset(pp_rank, vpp_rank, transformer_config)\n\n    if layer_name in name:  # belong to an intermediate layer\n        split_name = name.split(\".\")\n        # find the num next to split_name\n        for i, name in enumerate(split_name):\n            if name == layer_name:\n                break\n        layer_num_idx = i + 1\n        # check the name\n        assert len(split_name) >= layer_num_idx + 1, f\"split_name = {split_name}\"\n        assert split_name[layer_num_idx].isdigit(), f\"split_name = {split_name}\"\n        # increment layer_num_idx by layer_offset\n        split_name[layer_num_idx] = str(int(split_name[layer_num_idx]) + layer_offset)\n        name = \".\".join(split_name)  # weight name in inference_tp_model\n    return name\n\n\ndef normalize_pp_vpp_params(params, num_hidden_layers, layer_name=\"layers\"):\n    \"\"\"\n    Normalize the pp vpp params into a complete named parameters.\n    This is useful when gather parameters from pp ranks and passed to a model without pp\n\n    params: Iterable[List[Dict[str, param]]]\n        params contains a list of pp, with a list of vpp named_parameters in each vpp chunk.\n    output: Dict[str, param]\n\n    \"\"\"\n    pp_size = len(params)\n    for pp_rank in range(len(params)):\n        vpp_size = len(params[pp_rank])\n        for vpp_rank in range(vpp_size):\n            for name, param in params[pp_rank][vpp_rank].items():\n                normalized_name = normalize_model_name(\n                    name, pp_rank, vpp_rank, pp_size, vpp_size, num_hidden_layers, layer_name=layer_name\n                )\n                yield 
normalized_name, param\n\n\ndef get_parallel_model_from_config(\n    config, megatron_config, pre_process=None, post_process=None, share_embeddings_and_output_weights=False, value=False\n):\n    from megatron.core import ModelParallelConfig\n\n    assert isinstance(megatron_config, ModelParallelConfig)\n    model_class = _get_parallel_model_architecture_from_config(config, value)\n\n    model = model_class(\n        config,\n        megatron_config,\n        pre_process=pre_process,\n        post_process=post_process,\n        share_embeddings_and_output_weights=share_embeddings_and_output_weights,\n    )\n    return model\n\n\ndef _get_parallel_model_architecture_from_config(config: PretrainedConfig, value=False) -> Type[nn.Module]:\n    architectures = getattr(config, \"architectures\", [])\n    for arch in architectures:\n        model_cls = ModelRegistry.load_model_cls(arch, value)\n        logger.info(\"after load model cls\")\n        if model_cls is not None:\n            return model_cls\n    raise ValueError(\n        f\"Model architectures {architectures} are not supported for now. Supported architectures: \"\n        f\"{ModelRegistry.get_supported_archs()}\"\n    )\n\n\ndef _load_hf_model(config, model_config, is_value_model, local_cache_path):\n    \"\"\"Helper function containing the loading hf model logic\"\"\"\n    from accelerate import init_empty_weights\n    from megatron.core import parallel_state as mpu\n\n    from siirl.models.mcore.saver import _megatron_calc_global_rank\n\n    assert hasattr(model_config, \"architectures\"), \"architectures cannot be empty when load weight!\"\n    architectures = getattr(model_config, \"architectures\", [])\n    local_cache_path = os.path.expanduser(local_cache_path)\n\n    if config.model.path.startswith(\"hdfs:\"):\n        from siirl.utils.extras.fs import copy_to_local\n\n        logger.info(f\"start download from {config.model.path}\")\n        local_model_path = copy_to_local(\n            src=config.model.path, cache_dir=local_cache_path, use_shm=config.model.get(\"use_shm\", False)\n        )\n        logger.info(\"finish download\")\n    else:\n        local_model_path = config.model.path\n        logger.info(f\"load from local dir {local_model_path}\")\n\n    src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=0, cp_rank=mpu.get_context_parallel_rank())\n    cpu_init_weights = lambda: torch.device(\"cpu\")\n    init_context = init_empty_weights if torch.distributed.get_rank() != src_rank else cpu_init_weights\n    with init_context(), warnings.catch_warnings():\n        warnings.simplefilter(\"ignore\")\n        # TODO: to find a better way to load mistral7b-rm lm_head\n        if \"mistral7b-rm\" in config.model.path:\n            model = MistralForSequenceClassification.from_pretrained(\n                local_model_path,\n                torch_dtype=\"auto\",\n                # device_map=\"auto\",  # disable auto device_map, the HF weight is only loaded to CPU in src_rank\n                # low_cpu_mem_usage=True\n            )  # use score head instead of lm_head\n            state_dict = model.state_dict()\n            state_dict[\"lm_head.weight\"] = state_dict[\"score.weight\"]\n            state_dict[\"model.embed_tokens.weight\"] = state_dict[\"model.embed_tokens.weight\"][\n                :32000\n            ]  # workaround, 32001 -> 32000\n            is_value_model = True\n        else:\n            model = AutoModelForCausalLM.from_pretrained(\n                local_model_path,\n               
 torch_dtype=\"auto\",\n                # device_map=\"auto\", # disable auto device_map, the HF weight is only loaded to CPU in src_rank\n                # low_cpu_mem_usage=True\n            )\n            state_dict = model.state_dict()\n\n    return architectures, model, state_dict, is_value_model\n\n\ndef get_hf_model_path(config, local_cache_path=\"~/.cache/siirl/rlhf\"):\n    local_cache_path = os.path.expanduser(local_cache_path)\n    if config.model.path.startswith(\"hdfs:\"):\n        from siirl.utils.extras.fs import copy_to_local\n\n        local_model_path = copy_to_local(\n            src=config.model.path, cache_dir=local_cache_path, use_shm=config.model.use_shm\n        )\n    else:\n        local_model_path = config.model.path\n    return local_model_path\n\n\ndef load_megatron_model_weights(\n    config, model_config, parallel_model, params_dtype, is_value_model=False, local_cache_path=\"~/.cache/siirl/rlhf\"\n):\n    \"\"\"Load weights for siirl customized model.\"\"\"\n    architectures, model, state_dict, is_value_model = _load_hf_model(\n        config, model_config, is_value_model, local_cache_path\n    )\n\n    from siirl.models.weight_loader_registry import get_weight_loader\n\n    logger.info(f\"before weight loader: architectures = {architectures}...\")\n    for arch in architectures:\n        logger.info(f\"call weight loader arch = {arch}, model config = {model.config}\")\n        weight_loader = get_weight_loader(arch)\n        weight_loader(\n            state_dict=state_dict,\n            wrapped_models=parallel_model,\n            config=model.config,\n            params_dtype=params_dtype,\n            is_value_model=is_value_model,\n            tie_word_embeddings=model_config.tie_word_embeddings,\n        )\n    return model.config\n\n\ndef load_megatron_gptmodel_weights(\n    config, model_config, parallel_model, params_dtype, is_value_model=False, local_cache_path=\"~/.cache/siirl/rlhf\"\n):\n    \"\"\"Load weights for mcore GPT model.\"\"\"\n    _, model, state_dict, is_value_model = _load_hf_model(config, model_config, is_value_model, local_cache_path)\n\n    from siirl.models.mcore.loader import load_state_dict_to_megatron_gptmodel\n\n    load_state_dict_to_megatron_gptmodel(\n        state_dict=state_dict,\n        wrapped_models=parallel_model,\n        config=model.config,\n        params_dtype=params_dtype,\n        is_value_model=is_value_model,\n    )\n    del state_dict, model\n\n\n# pad input_ids_rmpad, cu_seqlens and max_seqlen_in_batch to be divisible by tp\ndef pad_packed_inputs(unpad_tokens: torch.Tensor, cu_seqlens, max_seqlen_in_batch, size):\n    \"\"\"pad the tokens such that the total length is a multiple of size.\n    This function is useful when applying sequence parallel and context parallel\n\n    Args:\n        unpad_tokens: (total_nnz, ...). 
Tokens after removing padding\n        cu_seqlens: (total_nnz + 1,)\n        max_seqlen_in_batch: int\n\n    Returns:\n\n    \"\"\"\n    F = nn.functional\n\n    total_nnz = unpad_tokens.shape[0]\n\n    pad_size = 0 if total_nnz % size == 0 else size - total_nnz % size\n\n    # we assume adding a new data in the batch with seqlen pad_size\n    if pad_size > 0:\n        if unpad_tokens.ndim == 1:\n            unpad_tokens = F.pad(unpad_tokens, (0, pad_size))\n        elif unpad_tokens.ndim == 2:\n            unpad_tokens = F.pad(unpad_tokens, (0, 0, 0, pad_size))\n        else:\n            raise NotImplementedError(f\"Padding dim {unpad_tokens.ndim()} is not supported\")\n\n        cu_seqlens = F.pad(cu_seqlens, (0, 1), value=pad_size + cu_seqlens[-1])\n        max_seqlen_in_batch = max(max_seqlen_in_batch, pad_size)\n\n    return unpad_tokens, cu_seqlens, max_seqlen_in_batch\n\n\ndef load_mcore_dist_weights(parallel_model, dist_weight_path, is_value_model=False):\n    from megatron.core import dist_checkpointing\n    from megatron.core.dist_checkpointing.serialization import StrictHandling\n    from megatron.core.models.gpt.gpt_model import GPTModel\n\n    # strict = StrictHandling.IGNORE_ALL if is_value_model else StrictHandling.ASSUME_OK_UNEXPECTED\n    strict = StrictHandling.ASSUME_OK_UNEXPECTED\n    for model in parallel_model:\n        if isinstance(model.module, GPTModel):\n            ssd = model.module.sharded_state_dict()\n        else:\n            ssd = model.module.module.sharded_state_dict()\n        if is_value_model:\n            for k in list(ssd.keys()):\n                if \"output_layer\" in k:\n                    ssd.pop(k)\n        dist_checkpointing.load(ssd, dist_weight_path, strict=strict)\n\n    return\n\n\ndef get_parallel_gptmodel_from_config(\n    tfconfig, hf_config, pre_process=None, post_process=None, share_embeddings_and_output_weights=False, value=False\n):\n    from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec\n    from megatron.core.models.gpt.gpt_model import GPTModel\n\n    use_te = True\n    assert tfconfig.normalization == \"RMSNorm\", \"only RMSNorm is supported for now\"\n    transformer_layer_spec = get_gpt_decoder_block_spec(tfconfig, use_transformer_engine=use_te)\n    rope_scaling_args = {}\n    if hf_config.rope_scaling is not None:\n        assert hf_config.rope_scaling[\"type\"] == \"linear\", \"only linear scaling is supported for now\"\n        rope_scaling_args[\"seq_len_interpolation_factor\"] = hf_config.rope_scaling[\"factor\"]\n    parallel_model = GPTModel(\n        config=tfconfig,\n        transformer_layer_spec=transformer_layer_spec,\n        vocab_size=hf_config.vocab_size,\n        max_sequence_length=hf_config.max_position_embeddings,\n        pre_process=pre_process,\n        post_process=post_process,\n        share_embeddings_and_output_weights=share_embeddings_and_output_weights,\n        position_embedding_type=\"rope\",\n        rotary_base=hf_config.rope_theta,\n        **rope_scaling_args,\n    )\n    # # for layer in parallel_model.decoder.layers:\n    # layer.self_attention.core_attention.flash_attention.softmax_scale = None\n    if post_process and value:\n        from siirl.models.llama.megatron.layers.parallel_linear import LinearForLastLayer\n\n        parallel_model.output_layer = LinearForLastLayer(\n            input_size=tfconfig.hidden_size, output_size=1, config=tfconfig\n        )\n    return parallel_model\n\n\ndef convert_weight_keys(state_dict: Dict[str, torch.Tensor], 
model: PreTrainedModel):\n    # convert state dict keys: https://github.com/huggingface/transformers/pull/38385\n    if not hasattr(model, \"_checkpoint_conversion_mapping\"):\n        return state_dict\n\n    reverse_key_mapping = {v: k for k, v in model._checkpoint_conversion_mapping.items()}\n    original_weights = {}\n    for key, value in state_dict.items():\n        for pattern, replacement in reverse_key_mapping.items():\n            replacement = replacement.lstrip(\"^\")  # strip off un-needed chars and patterns\n            replacement = re.sub(r\"\\(.*\\)\", \"\", replacement)\n            key, n_replace = re.subn(pattern, replacement, key)\n            # Early exit of the loop\n            if n_replace > 0:\n                break\n\n        original_weights[key] = value\n\n    return original_weights\n\n\ndef extract_multi_modal_inputs(\n    batch_data: list[dict[str, torch.Tensor]],\n    indices: Optional[list[int]] = None,\n) -> dict[str, torch.Tensor | list[torch.Tensor]]:\n    \"\"\"\n    Extract and process multi-modal inputs from a batch.\n\n    Args:\n        batch_data (list[dict[str, torch.Tensor]]): The batch containing potential multi-modal inputs\n        indices (Optional[list[int]]): If provided, only extract inputs at these indices\n\n    Returns:\n        dict[str, torch.Tensor | list[torch.Tensor]]: Processed multi-modal inputs ready for model consumption\n\n    \"\"\"\n    multi_modal_inputs = {}\n    multi_modal_inputs_collected = {}\n    has_image_bound = False\n\n    selected_batch_data = batch_data\n    if indices is not None:\n        selected_batch_data = [batch_data[i] for i in indices if i < len(batch_data)]\n\n    for inputs in selected_batch_data:\n        if \"image_bound\" in inputs:\n            has_image_bound = True\n        for key, value in inputs.items():\n            if value is not None:\n                if key not in multi_modal_inputs_collected:\n                    multi_modal_inputs_collected[key] = []\n                multi_modal_inputs_collected[key].append(value)\n\n    for key, values in multi_modal_inputs_collected.items():\n        if has_image_bound:  # minicpm-o logic\n            multi_modal_inputs[key] = values\n        else:\n            multi_modal_inputs[key] = torch.cat(values, dim=0)\n\n    return multi_modal_inputs\n\n\n@dataclass\nclass CausalLMOutputForPPO(CausalLMOutputWithPast):\n    log_probs: Optional[torch.FloatTensor] = None\n    entropy: Optional[torch.FloatTensor] = None\n"
  },
  {
    "path": "siirl/utils/model_utils/npu_utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport torch\nimport torch.nn.functional as F\nfrom einops import rearrange, repeat\n\n\n# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py\nclass IndexFirstAxis(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, input, indices):\n        ctx.save_for_backward(indices)\n        assert input.ndim >= 2\n        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]\n        second_dim = other_shape.numel()\n        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.\n        # return input[indices]\n        return torch.gather(rearrange(input, \"b ... -> b (...)\"), 0, repeat(indices, \"z -> z d\", d=second_dim)).reshape(\n            -1, *other_shape\n        )\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        (indices,) = ctx.saved_tensors\n        assert grad_output.ndim >= 2\n        other_shape = grad_output.shape[1:]\n        grad_output = rearrange(grad_output, \"b ... -> b (...)\")\n        grad_input = torch.zeros(\n            [ctx.first_axis_dim, grad_output.shape[1]],\n            device=grad_output.device,\n            dtype=grad_output.dtype,\n        )\n        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.\n        # grad_input[indices] = grad_output\n        grad_input.scatter_(0, repeat(indices, \"z -> z d\", d=grad_output.shape[1]), grad_output)\n        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None\n\n\nindex_first_axis = IndexFirstAxis.apply\n\n\n# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py\nclass IndexPutFirstAxis(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, values, indices, first_axis_dim):\n        ctx.save_for_backward(indices)\n        assert indices.ndim == 1\n        assert values.ndim >= 2\n        output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype)\n        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.\n        output[indices] = values\n        # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        (indices,) = ctx.saved_tensors\n        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.\n        grad_values = grad_output[indices]\n        # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))\n        return grad_values, None, None\n\n\nindex_put_first_axis = IndexPutFirstAxis.apply\n\n\n# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py\ndef pad_input(hidden_states, indices, batch, seqlen):\n    \"\"\"\n    Arguments:\n        hidden_states: (total_nnz, ...), where total_nnz = number of tokens in 
selected in attention_mask.\n        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.\n        batch: int, batch size for the padded sequence.\n        seqlen: int, maximum sequence length for the padded sequence.\n    Return:\n        hidden_states: (batch, seqlen, ...)\n    \"\"\"\n    # dim = hidden_states.shape[-1]\n    # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)\n    # output[indices] = hidden_states\n    output = index_put_first_axis(hidden_states, indices, batch * seqlen)\n    return rearrange(output, \"(b s) ... -> b s ...\", b=batch)\n\n\n# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py\ndef unpad_input(hidden_states, attention_mask, unused_mask=None):\n    \"\"\"\n    Arguments:\n        hidden_states: (batch, seqlen, ...)\n        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.\n        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.\n    Return:\n        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.\n        indices: (total_nnz), the indices of masked tokens from the flattened input sequence.\n        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.\n        max_seqlen_in_batch: int\n        seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.\n    \"\"\"\n    all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask\n    seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)\n    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)\n    indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()\n    max_seqlen_in_batch = seqlens_in_batch.max().item()\n    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))\n    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the\n    # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim\n    # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to\n    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,\n    # so we write custom forward and backward to make it a bit faster.\n    return (\n        index_first_axis(rearrange(hidden_states, \"b s ... -> (b s) ...\"), indices),\n        indices,\n        cu_seqlens,\n        max_seqlen_in_batch,\n        used_seqlens_in_batch,\n    )\n"
  },
  {
    "path": "siirl/utils/model_utils/seqlen_balancing.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport copy\nimport heapq\nfrom itertools import chain\nfrom typing import List, Tuple\n\nimport torch\nfrom torch import distributed as dist\n\nimport siirl.utils.model_utils.tensordict_utils as tu\nfrom siirl.utils.extras.device import get_device_name\nfrom tensordict import TensorDict\n\n\n\ndef calculate_workload(seqlen_list: list[int]):\n    \"\"\"\n    Calculate the workload for a dense transformer block based on sequence length.\n    FLOPs = 12 * hidden_size^2 * seqlen + 2 * hidden_size * seqlen^2\n    Hardcodes the constants by a 7B model (hidden_size=4096),\n    so the FLOPs are propotional to (6 * 4096 * seqlen + seqlen^2).\n    \"\"\"\n    if not isinstance(seqlen_list, torch.Tensor):\n        seqlen_list = torch.tensor(seqlen_list, dtype=torch.int64)\n    return 24576 * seqlen_list + seqlen_list**2\n\n\ndef karmarkar_karp(seqlen_list: list[int], k_partitions: int, equal_size: bool):\n    # see: https://en.wikipedia.org/wiki/Largest_differencing_method\n    class Set:\n        def __init__(self) -> None:\n            self.sum = 0\n            self.items = []\n\n        def add(self, idx: int, val: int):\n            self.items.append((idx, val))\n            self.sum += val\n\n        def merge(self, other):\n            for idx, val in other.items:\n                self.items.append((idx, val))\n                self.sum += val\n\n        def __lt__(self, other):\n            if self.sum != other.sum:\n                return self.sum < other.sum\n            if len(self.items) != len(other.items):\n                return len(self.items) < len(other.items)\n            return self.items < other.items\n\n    class State:\n        def __init__(self, items: list[tuple[int, int]], k: int) -> None:\n            self.k = k\n            # sets should always be decreasing order\n            self.sets = [Set() for _ in range(k)]\n            assert len(items) in [1, k], f\"{len(items)} not in [1, {k}]\"\n            for i, (idx, seqlen) in enumerate(items):\n                self.sets[i].add(idx=idx, val=seqlen)\n            self.sets = sorted(self.sets, reverse=True)\n\n        def get_partitions(self):\n            partitions = []\n            for i in range(len(self.sets)):\n                cur_partition = []\n                for idx, _ in self.sets[i].items:\n                    cur_partition.append(idx)\n                partitions.append(cur_partition)\n            return partitions\n\n        def merge(self, other):\n            for i in range(self.k):\n                self.sets[i].merge(other.sets[self.k - 1 - i])\n            self.sets = sorted(self.sets, reverse=True)\n\n        @property\n        def spread(self) -> int:\n            return self.sets[0].sum - self.sets[-1].sum\n\n        def __lt__(self, other):\n            # least heap, let the state with largest spread to be popped first,\n            # if the spread is the same, let the state who has 
the largest set\n            # to be popped first.\n            if self.spread != other.spread:\n                return self.spread > other.spread\n            return self.sets[0] > other.sets[0]\n\n        def __repr__(self) -> str:\n            repr_str = \"[\"\n            for i in range(self.k):\n                if i > 0:\n                    repr_str += \",\"\n                repr_str += \"{\"\n                for j, (_, seqlen) in enumerate(self.sets[i].items):\n                    if j > 0:\n                        repr_str += \",\"\n                    repr_str += str(seqlen)\n                repr_str += \"}\"\n            repr_str += \"]\"\n            return repr_str\n\n    sorted_seqlen_list = sorted([(seqlen, i) for i, seqlen in enumerate(seqlen_list)])\n    states_pq = []\n    if equal_size:\n        assert len(seqlen_list) % k_partitions == 0, f\"{len(seqlen_list)} % {k_partitions} != 0\"\n        for offset in range(0, len(sorted_seqlen_list), k_partitions):\n            items = []\n            for i in range(k_partitions):\n                seqlen, idx = sorted_seqlen_list[offset + i]\n                items.append((idx, seqlen))\n            heapq.heappush(states_pq, State(items=items, k=k_partitions))\n    else:\n        for seqlen, idx in sorted_seqlen_list:\n            heapq.heappush(states_pq, State(items=[(idx, seqlen)], k=k_partitions))\n\n    while len(states_pq) > 1:\n        state0 = heapq.heappop(states_pq)\n        state1 = heapq.heappop(states_pq)\n        # merge states\n        state0.merge(state1)\n        heapq.heappush(states_pq, state0)\n\n    final_state = states_pq[0]\n    partitions = final_state.get_partitions()\n    if equal_size:\n        for i, partition in enumerate(partitions):\n            assert len(partition) * k_partitions == len(seqlen_list), (\n                f\"{len(partition)} * {k_partitions} != {len(seqlen_list)}\"\n            )\n    return partitions\n\n\ndef greedy_partition(seqlen_list: list[int], k_partitions: int, equal_size: bool):\n    bias = sum(seqlen_list) + 1 if equal_size else 0\n    sorted_seqlen = [(seqlen + bias, i) for i, seqlen in enumerate(seqlen_list)]\n    partitions = [[] for _ in range(k_partitions)]\n    partition_sums = [0 for _ in range(k_partitions)]\n    for seqlen, i in sorted_seqlen:\n        min_idx = None\n        for j in range(k_partitions):\n            if min_idx is None or partition_sums[j] < partition_sums[min_idx]:\n                min_idx = j\n        partitions[min_idx].append(i)\n        partition_sums[min_idx] += seqlen\n    if equal_size:\n        for i, partition in enumerate(partitions):\n            assert len(partition) * k_partitions == len(seqlen_list), (\n                f\"{len(partition)} * {k_partitions} != {len(seqlen_list)}\"\n            )\n    return partitions\n\n\ndef get_seqlen_balanced_partitions(seqlen_list: list[int], k_partitions: int, equal_size: bool):\n    \"\"\"\n    Calculates partitions of indices from seqlen_list such that the sum of sequence lengths\n    in each partition is balanced. 
Uses the Karmarkar-Karp differencing method.\n\n    This is useful for balancing workload across devices or batches, especially when\n    dealing with variable sequence lengths.\n\n    Args:\n        seqlen_list (List[int]): A list of sequence lengths for each item.\n        k_partitions (int): The desired number of partitions.\n        equal_size (bool): If True, ensures that each partition has the same number of items.\n                           Requires len(seqlen_list) to be divisible by k_partitions.\n                           If False, partitions can have varying numbers of items, focusing\n                           only on balancing the sum of sequence lengths.\n\n    Returns:\n        List[List[int]]: A list containing k_partitions lists. Each inner list contains the\n                         original indices of the items assigned to that partition. The indices\n                         within each partition list are sorted.\n\n    Raises:\n        AssertionError: If len(seqlen_list) < k_partitions.\n        AssertionError: If equal_size is True and len(seqlen_list) is not divisible by k_partitions.\n        AssertionError: If any resulting partition is empty.\n    \"\"\"\n    assert len(seqlen_list) >= k_partitions, f\"number of items:[{len(seqlen_list)}] < k_partitions:[{k_partitions}]\"\n\n    def _check_and_sort_partitions(partitions):\n        assert len(partitions) == k_partitions, f\"{len(partitions)} != {k_partitions}\"\n        seen_idx = set()\n        sorted_partitions = [None] * k_partitions\n        for i, partition in enumerate(partitions):\n            assert len(partition) > 0, f\"the {i}-th partition is empty\"\n            for idx in partition:\n                seen_idx.add(idx)\n            sorted_partitions[i] = sorted(partition)\n        assert seen_idx == set(range(len(seqlen_list)))\n        return sorted_partitions\n\n    partitions = karmarkar_karp(seqlen_list=seqlen_list, k_partitions=k_partitions, equal_size=equal_size)\n    return _check_and_sort_partitions(partitions)\n\n\ndef log_seqlen_unbalance(seqlen_list: list[int], partitions: list[list[int]], prefix):\n    \"\"\"\n    Calculate and log metrics related to sequence length imbalance before and after partitioning.\n\n    Args:\n        seqlen_list (List[int]): A list of sequence lengths for each item.\n        partitions (List[List[int]]): A list of partitions, where each inner list contains indices\n                                      from seqlen_list assigned to that partition.\n        prefix (str): A prefix to be added to each metric key in the returned dictionary.\n\n    Returns:\n        dict: A dictionary containing metrics related to sequence length imbalance.\n    \"\"\"\n    # Get the number of partitions\n    k_partition = len(partitions)\n    # assert len(seqlen_list) % k_partition == 0\n    batch_size = len(seqlen_list) // k_partition\n    min_sum_seqlen = None\n    max_sum_seqlen = None\n    total_sum_seqlen = 0\n\n    # Iterate over each batch of sequence lengths\n    for offset in range(0, len(seqlen_list), batch_size):\n        cur_sum_seqlen = sum(seqlen_list[offset : offset + batch_size])\n        if min_sum_seqlen is None or cur_sum_seqlen < min_sum_seqlen:\n            min_sum_seqlen = cur_sum_seqlen\n        if max_sum_seqlen is None or cur_sum_seqlen > max_sum_seqlen:\n            max_sum_seqlen = cur_sum_seqlen\n        total_sum_seqlen += cur_sum_seqlen\n\n    balanced_sum_seqlen_list = []\n    for partition in partitions:\n        cur_sum_seqlen_balanced = 
sum([seqlen_list[i] for i in partition])\n        balanced_sum_seqlen_list.append(cur_sum_seqlen_balanced)\n    # print(\"balanced_sum_seqlen_list: \", balanced_sum_seqlen_list)\n    min_sum_seqlen_balanced = min(balanced_sum_seqlen_list)\n    max_sum_seqlen_balanced = max(balanced_sum_seqlen_list)\n\n    return {\n        f\"{prefix}/min\": min_sum_seqlen,\n        f\"{prefix}/max\": max_sum_seqlen,\n        f\"{prefix}/minmax_diff\": max_sum_seqlen - min_sum_seqlen,\n        f\"{prefix}/balanced_min\": min_sum_seqlen_balanced,\n        f\"{prefix}/balanced_max\": max_sum_seqlen_balanced,\n        f\"{prefix}/mean\": total_sum_seqlen / len(partitions),\n    }\n\n\ndef ceildiv(a, b):\n    return -(a // -b)\n\n\ndef roundup_divisible(a, b):\n    return ((a + b - 1) // b) * b\n\n\ndef rearrange_micro_batches(\n    batch,\n    max_token_len,\n    dp_group=None,\n    num_batches_divided_by=None,\n    same_micro_num_in_dp=True,\n    min_num_micro_batch=None,\n    use_dynamic_bsz_balance=True,\n):\n    \"\"\"\n    Split a batch into micro-batches by total token count, with optional DP sync and padding.\n\n    Args:\n        batch (TensorDict): must include \"attention_mask\" (B*S); other fields are sliced similarly.\n        max_token_len (int): max sum of attention_mask per micro-batch.\n        dp_group (optional): torch.distributed group for data-parallel sync.\n        num_batches_divided_by (optional): virtual pipeline parallel size, for megatron.\n        same_micro_num_in_dp (bool): if True and dp_group set, pad all ranks to the same count.\n        min_num_micro_batch (int, optional): force at least this many splits (pads empty ones).\n        use_dynamic_bsz_balance (bool, optional): balance the computational workload between micro-batches\n\n    Returns:\n        List[TensorDict]: the micro-batches.\n        List[List[int]]: index lists mapping each micro-batch back to original positions.\n    \"\"\"\n    # this is per local micro_bsz\n    input_ids = batch[\"input_ids\"]\n    if input_ids.is_nested:\n        seq_len_effective: torch.Tensor = input_ids.offsets().diff()\n        max_seq_len = max(seq_len_effective)\n    else:\n        max_seq_len = batch[\"attention_mask\"].shape[-1]\n        seq_len_effective: torch.Tensor = batch[\"attention_mask\"].sum(dim=1)\n\n    assert max_token_len >= max_seq_len, (\n        f\"max_token_len must be greater than the sequence length. 
Got {max_token_len=} and {max_seq_len=}\"\n    )\n    total_seqlen = seq_len_effective.sum().item()\n    # NOTE: num_microbatches <= batch_size, so take the min of these two.\n    num_micro_batches = min(len(seq_len_effective), ceildiv(total_seqlen, max_token_len))\n    if min_num_micro_batch is not None:\n        # used to support pp\n        num_micro_batches = max(min_num_micro_batch, num_micro_batches)\n    if dist.is_initialized() and same_micro_num_in_dp:\n        num_micro_batches = torch.tensor([num_micro_batches], device=get_device_name())\n        dist.all_reduce(num_micro_batches, op=dist.ReduceOp.MAX, group=dp_group)\n        num_micro_batches = num_micro_batches.cpu().item()\n    if num_batches_divided_by is not None:\n        num_micro_batches = roundup_divisible(num_micro_batches, num_batches_divided_by)\n\n    assert num_micro_batches <= len(seq_len_effective)\n\n    workloads = calculate_workload(seq_len_effective)\n    micro_bsz_idx = get_seqlen_balanced_partitions(workloads, num_micro_batches, equal_size=False)\n\n    if use_dynamic_bsz_balance:\n        # Use the sum of squared sequence lengths to approximate attention computation workload\n        micro_bsz_idx.sort(\n            key=lambda partition: (\n                sum(workloads[idx] for idx in partition),\n                partition[0] if partition else 0,\n            ),\n            reverse=True,\n        )\n        # Place smaller micro-batches at both ends to reduce the bubbles exposed during the warm-up and cool-down.\n        micro_bsz_idx = micro_bsz_idx[::2][::-1] + micro_bsz_idx[1::2]\n\n    micro_batches = []\n\n    for partition in micro_bsz_idx:\n        curr_micro_batch = tu.index_select_tensor_dict(batch, partition)\n        micro_batches.append(curr_micro_batch)\n\n    return micro_batches, micro_bsz_idx\n\n\ndef get_reverse_idx(idx_map):\n    \"\"\"\n    Build the inverse of an index mapping.\n\n    Args:\n        idx_map (Sequence[int]): Sequence where idx_map[i] = j.\n\n    Returns:\n        List[int]: Inverse mapping list such that output[j] = i for each i.\n    \"\"\"\n    reverse_idx_map = copy.deepcopy(idx_map)\n\n    for i, idx in enumerate(idx_map):\n        reverse_idx_map[idx] = i\n\n    return reverse_idx_map\n\n\ndef prepare_dynamic_batch(\n    data: TensorDict,\n    max_token_len: int,\n    dp_group=None,\n    num_batches_divided_by=None,\n    same_micro_num_in_dp=True,\n    min_num_micro_batch=None,\n    use_dynamic_bsz_balance=True,\n) -> tuple[list[TensorDict], list[list[int]]]:\n    \"\"\"\n    Prepare a batch for dynamic batching.\n\n    Args:\n        data (TensorDict): The input data.\n        max_token_len (int): The maximum token length for dynamic batching.\n\n    Returns:\n        Tuple[List[TensorDict], List[List[int]]]: A tuple containing a list of TensorDict objects\n        and a list of index lists.\n    \"\"\"\n    batch, batch_idx_list = rearrange_micro_batches(\n        data,\n        max_token_len=max_token_len,\n        dp_group=dp_group,\n        num_batches_divided_by=num_batches_divided_by,\n        same_micro_num_in_dp=same_micro_num_in_dp,\n        min_num_micro_batch=min_num_micro_batch,\n        use_dynamic_bsz_balance=use_dynamic_bsz_balance,\n    )\n    micro_batches = []\n    for i, batch_idx in enumerate(batch_idx_list):\n        micro_batches.append(batch[i])\n\n    return micro_batches, batch_idx_list
\n\ndef restore_dynamic_batch(data: torch.Tensor, batch_idx_list: list[list[int]]) -> torch.Tensor:\n    \"\"\"\n    Restore a batch from dynamic batching.\n\n    Args:\n        data (torch.Tensor): The input data.\n        batch_idx_list (List[List[int]]): The list of index lists.\n\n    Returns:\n        torch.Tensor: The restored data.\n    \"\"\"\n    indices = list(chain.from_iterable(batch_idx_list))\n    batch_size = data.shape[0]\n    assert len(indices) == batch_size, f\"{len(indices)} vs. {batch_size}\"\n    revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)\n\n    if data.is_nested:\n        tensors = [data[i] for i in revert_indices]\n        reverted_data = torch.nested.as_nested_tensor(tensors, layout=torch.jagged)\n    else:\n        reverted_data = data[revert_indices]\n\n    return reverted_data\n"
  },
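The index lists returned by `rearrange_micro_batches` are the contract that `restore_dynamic_batch` relies on: concatenating the micro-batches and indexing with the inverse permutation recovers the original row order. A minimal, self-contained sketch of that round trip, using plain torch and a made-up `batch_idx_list` (no `siirl` imports required):

```python
import torch
from itertools import chain

batch_idx_list = [[2, 0], [3, 1]]          # hypothetical micro-batch index lists
data = torch.arange(4 * 3).reshape(4, 3)   # original batch of 4 rows

# forward: gather rows per micro-batch, then concatenate in micro-batch order
shuffled = torch.cat([data[idx] for idx in batch_idx_list], dim=0)

# backward: build the inverse permutation (the same logic as get_reverse_idx)
indices = list(chain.from_iterable(batch_idx_list))
reverse = [0] * len(indices)
for i, idx in enumerate(indices):
    reverse[idx] = i

restored = shuffled[torch.tensor(reverse, dtype=torch.long)]
assert torch.equal(restored, data)
```

This is exactly the inversion `get_reverse_idx` computes before `restore_dynamic_batch` gathers rows (or nested-tensor components) back into their original positions.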
  {
    "path": "siirl/utils/model_utils/tensordict_utils.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nfrom typing import Iterator\n\nimport torch\nfrom tensordict import TensorDict\n# Handle different tensordict versions - NonTensorData location varies\ntry:\n    from tensordict import NonTensorData, NonTensorStack\nexcept ImportError:\n    from tensordict.tensorclass import NonTensorData, NonTensorStack\n\n\ndef assign_non_tensor_dict(tensor_dict: TensorDict, non_tensor_dict: dict):\n    for key, val in non_tensor_dict.items():\n        assign_non_tensor_data(tensor_dict=tensor_dict, key=key, val=val)\n    return tensor_dict\n\n\ndef assign_non_tensor_data(tensor_dict: TensorDict, key, val):\n    tensor_dict[key] = NonTensorData(val)\n\n\ndef assign_non_tensor(tensordict: TensorDict, **kwargs):\n    for key, val in kwargs.items():\n        assign_non_tensor_data(tensor_dict=tensordict, key=key, val=val)\n    return tensordict\n\n\ndef unwrap_non_tensor_data(data):\n    if isinstance(data, NonTensorData):\n        return data.data\n    return data\n\n\ndef get_non_tensor_data(data: TensorDict, key: str, default):\n    output = data.get(key, default)\n    return unwrap_non_tensor_data(output)\n\n\ndef get_tensordict(tensor_dict: dict[str, torch.Tensor | list], non_tensor_dict: dict = None) -> TensorDict:\n    \"\"\"\n\n    Args:\n        data_dict:\n        meta_info:\n\n    Returns:\n\n    \"\"\"\n    if non_tensor_dict is None:\n        non_tensor_dict = {}\n\n    batch_size = None\n\n    for key, val in tensor_dict.items():\n        if isinstance(val, list):\n            for v in val:\n                assert not isinstance(v, torch.Tensor), (\n                    \"Passing a list makes the data NonTensorStack, \"\n                    \"which doesn't support torch.Tensor. Please convert to numpy first\"\n                )\n        assert isinstance(val, torch.Tensor | list)\n\n        if batch_size is None:\n            batch_size = val.size(0) if isinstance(val, torch.Tensor) else len(val)\n        else:\n            val_batch_size = val.size(0) if isinstance(val, torch.Tensor) else len(val)\n            assert val_batch_size == batch_size, (\n                f\"Batch size of tensor {key} is not consistent with other tensors. 
\"\n                f\"Expected {batch_size}, got {val_batch_size}\"\n            )\n\n    if batch_size is None:\n        batch_size = []\n    else:\n        batch_size = [batch_size]\n\n    for key, val in non_tensor_dict.items():\n        assert key not in tensor_dict\n        tensor_dict[key] = NonTensorData(val)\n\n    return TensorDict(source=tensor_dict, batch_size=batch_size)\n\n\ndef index_select_tensor_dict(batch: TensorDict, indices: torch.Tensor | list[int]) -> TensorDict:\n    \"\"\"Index a tensor dict with a tensor of indices.\"\"\"\n    if isinstance(indices, list):\n        indices = torch.tensor(indices)\n\n    assert indices.dim() == 1, \"indices must be a 1D tensor\"\n\n    data_dict = {}\n    batch_size = indices.shape[0]\n\n    if batch is not None:\n        for key, tensor in batch.items():\n            if isinstance(tensor, torch.Tensor) and not tensor.is_nested:\n                data_dict[key] = tensor[indices]\n            elif isinstance(tensor, torch.Tensor) and tensor.is_nested:\n                data_dict[key] = torch.nested.as_nested_tensor([tensor[idx] for idx in indices], layout=torch.jagged)\n            else:\n                # This handles NonTensorStack (indexable by batch dim) and NonTensorData (scalar metadata).\n                if tensor.shape:\n                    data_dict[key] = tensor[indices]\n                else:\n                    data_dict[key] = tensor\n        selected_batch = TensorDict(source=data_dict, batch_size=batch_size)\n    else:\n        selected_batch = None\n\n    return selected_batch\n\n\ndef union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> TensorDict:\n    \"\"\"Union two tensordicts.\"\"\"\n    assert tensor_dict1.batch_size == tensor_dict2.batch_size, (\n        f\"Two tensor dict must have identical batch size. 
Got {tensor_dict1.batch_size} and {tensor_dict2.batch_size}\"\n    )\n    for key in tensor_dict2.keys():\n        if key not in tensor_dict1.keys():\n            tensor_dict1[key] = tensor_dict2[key]\n        else:\n            if isinstance(tensor_dict2[key], torch.Tensor):\n                assert tensor_dict1[key].equal(tensor_dict2[key]), (\n                    f\"{key} in tensor_dict1 and tensor_dict2 are not the same object\"\n                )\n            else:\n                # non-tensor\n                assert tensor_dict1[key] == tensor_dict2[key], (\n                    f\"{key} in tensor_dict1 and tensor_dict2 are not the same object\"\n                )\n\n    return tensor_dict1\n\n\ndef make_iterator(tensordict: TensorDict, mini_batch_size, epochs, seed=None, dataloader_kwargs=None):\n    from torch.utils.data import DataLoader\n\n    assert tensordict.batch_size[0] % mini_batch_size == 0, f\"{tensordict.batch_size[0]} % {mini_batch_size} != 0\"\n    # we can directly create a dataloader from TensorDict\n    if dataloader_kwargs is None:\n        dataloader_kwargs = {}\n\n    if seed is not None:\n        generator = torch.Generator()\n        generator.manual_seed(seed)\n    else:\n        generator = None\n\n    assert isinstance(dataloader_kwargs, dict)\n    train_dataloader = DataLoader(\n        dataset=tensordict, batch_size=mini_batch_size, collate_fn=lambda x: x, generator=generator, **dataloader_kwargs\n    )\n\n    def get_data():\n        for _ in range(epochs):\n            yield from train_dataloader\n\n    return iter(get_data())\n\n\ndef assert_tensordict_eq(tensordict1: TensorDict, tensordict2: TensorDict):\n    assert set(tensordict1.keys()) == set(tensordict2.keys())\n\n    for key in tensordict1.keys():\n        val = tensordict1[key]\n        val2 = tensordict2[key]\n\n        assert type(val) is type(val2), f\"The type of {key} must be the same. Got {type(val)} vs {type(val2)}\"\n\n        if isinstance(val, torch.Tensor):\n            if val.is_nested:\n                assert val.is_nested and val2.is_nested, (\n                    f\"Both tensors must be nested tensors. {val.is_nested=}, {val2.is_nested=}\"\n                )\n                t1, t2 = val.unbind(), val2.unbind()\n                assert len(t1) == len(t2), f\"Nested tensor should have the same lengths. {len(t1)=} vs {len(t2)=}\"\n                for c1, c2 in zip(t1, t2, strict=True):\n                    assert torch.equal(c1, c2), f\"Nested tensor components have different values. 
{c1=} vs {c2=}\"\n            else:\n                assert torch.all(torch.eq(val, val2)).item()\n        else:\n            assert val == val2\n\n\ndef pop(tensordict: TensorDict, keys: Iterator[str]) -> TensorDict:\n    tensor_output = {}\n    non_tensor_output = {}\n    for key in keys:\n        output = tensordict.get(key)\n        if isinstance(output, torch.Tensor):\n            tensor_output[key] = tensordict.pop(key)\n        elif isinstance(output, NonTensorStack):\n            tensor_output[key] = tensordict.pop(key).tolist()\n        else:\n            assert isinstance(output, NonTensorData)\n            non_tensor_output[key] = tensordict.pop(key)\n\n    return get_tensordict(tensor_output, non_tensor_output)\n\n\ndef pad_to_divisor(data: TensorDict, size_divisor: int):\n    \"\"\"Pad a TensorDict to size divisible by size_divisor\n\n    Args:\n        size_divisor (int): size divisor\n\n    Returns:\n        data: (TensorDict): the padded TensorDict\n        pad_size (int)\n    \"\"\"\n    assert isinstance(data, TensorDict), \"data must be a TensorDict\"\n    if len(data) % size_divisor != 0:\n        pad_size = size_divisor - len(data) % size_divisor\n        padding_protos = []\n        remaining_pad = pad_size\n        while remaining_pad > 0:\n            take_size = min(remaining_pad, len(data))\n            padding_protos.append(data[:take_size])\n            remaining_pad -= take_size\n        data_padded = torch.cat([data] + padding_protos)\n    else:\n        if len(data) == 0:\n            logging.warning(\"padding a TensorDict with no item, no changed made\")\n        pad_size = 0\n        data_padded = data\n    return data_padded, pad_size\n\n\ndef unpad(data: TensorDict, pad_size):\n    \"\"\"Unpad the data proto with pad_size. i.e. `data[:-pad_size]`\"\"\"\n    if pad_size != 0:\n        data = data[:-pad_size]\n    return data\n"
  },
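A usage sketch for the TensorDict helpers above, assuming `siirl` is installed and importable as `siirl.utils.model_utils.tensordict_utils` (field names and values are made up): build a batch with a tensor field plus scalar metadata, then select a subset of rows.

```python
import torch
from siirl.utils.model_utils import tensordict_utils as tu

# tensor fields define the batch size; scalar metadata is wrapped as NonTensorData
td = tu.get_tensordict(
    tensor_dict={"input_ids": torch.arange(12).reshape(4, 3)},
    non_tensor_dict={"temperature": 1.0},
)

# select rows 2 and 0; scalar metadata entries are carried over unchanged
subset = tu.index_select_tensor_dict(td, [2, 0])
print(subset["input_ids"].tolist())                             # [[6, 7, 8], [0, 1, 2]]
print(tu.get_non_tensor_data(td, "temperature", default=None))  # 1.0
```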
  {
    "path": "siirl/utils/model_utils/torch_dtypes.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nAdapted from Cruise.\n\"\"\"\n\nfrom typing import Union\n\nimport torch\n\nHALF_LIST = [16, \"16\", \"fp16\", \"float16\", torch.float16]\nFLOAT_LIST = [32, \"32\", \"fp32\", \"float32\", torch.float32]\nBFLOAT_LIST = [\"bf16\", \"bfloat16\", torch.bfloat16]\n\n\nclass PrecisionType:\n    \"\"\"Type of precision used.\n\n    >>> PrecisionType.HALF == 16\n    True\n    >>> PrecisionType.HALF in (16, \"16\")\n    True\n    \"\"\"\n\n    HALF = \"16\"\n    FLOAT = \"32\"\n    FULL = \"64\"\n    BFLOAT = \"bf16\"\n    MIXED = \"mixed\"\n\n    @staticmethod\n    def supported_type(precision: Union[str, int]) -> bool:\n        return any(x == precision for x in PrecisionType)\n\n    @staticmethod\n    def supported_types() -> list[str]:\n        return [x.value for x in PrecisionType]\n\n    @staticmethod\n    def is_fp16(precision):\n        return precision in HALF_LIST\n\n    @staticmethod\n    def is_fp32(precision):\n        return precision in FLOAT_LIST\n\n    @staticmethod\n    def is_bf16(precision):\n        return precision in BFLOAT_LIST\n\n    @staticmethod\n    def to_dtype(precision):\n        if precision in HALF_LIST:\n            return torch.float16\n        elif precision in FLOAT_LIST:\n            return torch.float32\n        elif precision in BFLOAT_LIST:\n            return torch.bfloat16\n        else:\n            raise RuntimeError(f\"unexpected precision: {precision}\")\n\n    @staticmethod\n    def to_str(precision):\n        if precision == torch.float16:\n            return \"fp16\"\n        elif precision == torch.float32:\n            return \"fp32\"\n        elif precision == torch.bfloat16:\n            return \"bf16\"\n        else:\n            raise RuntimeError(f\"unexpected precision: {precision}\")\n"
  },
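A small sketch of how the precision aliases above resolve (assuming `siirl` is importable); the assertions simply restate the `HALF_LIST`/`FLOAT_LIST`/`BFLOAT_LIST` alias tables:

```python
import torch
from siirl.utils.model_utils.torch_dtypes import PrecisionType

assert PrecisionType.to_dtype("bf16") is torch.bfloat16
assert PrecisionType.to_dtype(16) is torch.float16          # int alias from HALF_LIST
assert PrecisionType.to_str(torch.float32) == "fp32"
assert PrecisionType.is_bf16("bfloat16") and PrecisionType.is_fp16("float16")
```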
  {
    "path": "siirl/utils/model_utils/torch_functional.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nContain small torch utilities\n\"\"\"\n\nimport math\nfrom contextlib import contextmanager\nfrom typing import Dict, List, Optional, Union\n\nimport torch\nimport torch.distributed\nimport torch.nn.functional as F\nfrom tensordict import TensorDict\nfrom torch import nn\nfrom torch.optim import Optimizer\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom transformers import PreTrainedTokenizer\n\nfrom siirl.utils.extras.device import get_device_name, get_torch_device\n\ntry:\n    from flash_attn.ops.triton.cross_entropy import cross_entropy_loss\n\n    FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE = True\nexcept ImportError:\n    FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE = False\n\n\ndef gather_from_labels(data, label):\n    \"\"\"Gather the label from data. The value in label should be [0, vocab_size)\n\n    Args:\n        data: (..., vocab_size)\n        label (torch.IntTensor) : (...,)\n\n    Returns:\n\n    \"\"\"\n\n    output = torch.gather(data, -1, label.unsqueeze(-1)).squeeze(-1)\n    return output\n\n\ndef logprobs_from_logits(logits, labels, inplace_backward=True):\n    \"\"\"\n    Compute per-token log-probabilities for the given labels.\n\n    Uses a Flash-Attention–based cross-entropy (if available) for efficient backward,\n    otherwise falls back to a standard log-softmax+gather approach.\n\n    See: https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591\n\n    Args:\n        logits (Tensor): Model outputs of shape (..., vocab_size).\n        labels (LongTensor): True class indices of shape matching logits[..., :-1].\n        inplace_backward (bool): If True and Flash-Attn is available, perform backward in-place.\n\n    Returns:\n        Tensor: Log-probabilities of the target labels, shape logits.shape[:-1].\n    \"\"\"\n    if FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE:\n        batch_dim = logits.shape[:-1]\n        last_dim = logits.shape[-1]\n        logits = logits.reshape(-1, last_dim)\n        labels = labels.reshape(-1)\n        output = logprobs_from_logits_flash_attn(logits, labels, inplace_backward=inplace_backward)\n        output = output.view(*batch_dim)\n    else:\n        output = logprobs_from_logits_v2(logits, labels)\n    return output\n\n\ndef logprobs_from_logits_flash_attn(logits, labels, inplace_backward=True):\n    output = cross_entropy_loss(logits, labels, inplace_backward=inplace_backward)\n    assert isinstance(output, tuple), (\n        \"please make sure flash-attn>=2.4.3 where cross_entropy_loss returns Tuple[losses, z_losses].\"\n    )\n    return -output[0]\n\n\ndef logprobs_from_logits_naive(logits, labels):\n    logp = F.log_softmax(logits, dim=-1)\n    logpy = gather_from_labels(logp, labels)\n    return logpy\n\n\ndef logprobs_from_logits_v2(logits: torch.FloatTensor, labels):\n    \"\"\"\n    A memory efficient implementation of logprobs_from_logits\n    \"\"\"\n    if logits.dtype in [torch.float32, 
torch.float64]:\n        logits_labels = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)\n        # loop to reduce peak mem consumption\n        logsumexp_values = torch.stack([torch.logsumexp(logit, dim=-1) for logit in logits])\n        logprobs_labels = logits_labels - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)\n    else:\n        # logsumexp approach is unstable with bfloat16, fall back to slightly less efficient approach\n        logprobs_labels = []\n        for row_logits, row_labels in zip(logits, labels):  # loop to reduce peak mem consumption\n            row_logprobs = F.log_softmax(row_logits, dim=-1)\n            row_logprobs_labels = row_logprobs.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1)\n            logprobs_labels.append(row_logprobs_labels)\n        logprobs_labels = torch.stack(logprobs_labels)\n    return logprobs_labels\n\n\ndef clip_by_value(x, tensor_min, tensor_max):\n    \"\"\"\n    Tensor extension to torch.clamp\n    https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713\n    \"\"\"\n    clipped = torch.max(torch.min(x, tensor_max), tensor_min)\n    return clipped\n\n\ndef entropy_from_logits(logits: torch.Tensor):\n    \"\"\"Calculate entropy from logits.\"\"\"\n    pd = torch.nn.functional.softmax(logits, dim=-1)\n    entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1)\n    return entropy\n\n\ndef entropy_from_logits_with_chunking(logits: torch.Tensor, chunk_size: int = 2048):\n    \"\"\"Memory-efficient entropy calculation with chunking.\"\"\"\n    entropy = torch.zeros(logits.shape[0], device=logits.device)\n    for i in range(0, logits.shape[0], chunk_size):\n        logits_chunk = logits[i : i + chunk_size].float()\n        pd_chunk = torch.nn.functional.softmax(logits_chunk, dim=-1)\n        entropy_chunk = torch.logsumexp(logits_chunk, dim=-1) - torch.sum(pd_chunk * logits_chunk, dim=-1)\n        entropy[i : i + chunk_size] = entropy_chunk\n    return entropy\n\n\ndef masked_sum(values, mask, axis=None):\n    \"\"\"Compute mean of tensor with a masked values.\"\"\"\n    # If NaNs exist out of mask, replace NaNs in values with a value that\n    # won't affect the sum (e.g., 0 for masked regions)\n    valid_values = torch.where(mask.bool(), values, 0.0)\n    return (valid_values * mask).sum(axis=axis)\n\n\ndef masked_mean(values, mask, axis=None):\n    \"\"\"\n    Compute the mean of `values` over elements selected by `mask`.\n\n    Args:\n        values (Tensor): Input tensor.\n        mask (Tensor): Boolean or numeric mask of the same shape as `values`.\n        axis (int or tuple of int, optional): Dimension(s) along which to compute the mean.\n            Defaults to None (over all elements).\n\n    Returns:\n        Tensor: Masked mean, with shape equal to `values` reduced over `axis`.\n    \"\"\"\n    s = masked_sum(values, mask, axis)\n    return s / (mask.sum(axis=axis) + 1e-8)\n\n\ndef masked_var(values, mask, unbiased=True):\n    \"\"\"Compute variance of tensor with masked values.\"\"\"\n    mean = masked_mean(values, mask)\n    centered_values = values - mean\n    variance = masked_mean(centered_values**2, mask)\n    if unbiased:\n        mask_sum = mask.sum()\n        if mask_sum == 0:\n            raise ValueError(\"At least one element in the mask has to be 1.\")\n        # note that if mask_sum == 1, then there is a division by zero issue\n        # to avoid it you just need to use a larger minibatch_size\n        if mask_sum == 1:\n            raise 
ValueError(\"The sum of the mask is one, which can cause a division by zero.\")\n        bessel_correction = mask_sum / (mask_sum - 1)\n        variance = variance * bessel_correction\n    return variance\n\n\ndef masked_whiten(values, mask, shift_mean=True):\n    \"\"\"\n    Whiten `values` by normalizing with mean and variance computed over `mask`.\n\n    Args:\n        values (torch.Tensor): Input tensor.\n        mask (torch.Tensor): Boolean tensor of same shape, selects elements for stats.\n        shift_mean (bool): If True (default), output is zero-mean;\n                           if False, the original mean is re-added after scaling.\n\n    Returns:\n        torch.Tensor: Whitened tensor of same shape as `values`.\n    \"\"\"\n    mean, var = masked_mean(values, mask), masked_var(values, mask)\n    whitened = (values - mean) * torch.rsqrt(var + 1e-8)\n    if not shift_mean:\n        whitened += mean\n    return whitened\n\n\ndef get_response_mask(response_id: torch.Tensor, eos_token: Union[int, List[int]] = 2, dtype=torch.int64):\n    \"\"\"\n    end of sentence token can be int or list: 1 or [1, 2]\n    e.g.\n    response_id = torch.tensor([[20, 10, 34, 1, 0, 0, 0],\n                                [78, 0, 76, 2, 1, 0, 0],\n                                [23, 98, 1, 0, 0, 0, 0],\n                                [33, 3, 98, 45, 1, 0, 0]])\n    #eos_token=1\n    response_mask:  tensor([[1, 1, 1, 1, 0, 0, 0],\n                            [1, 1, 1, 1, 1, 0, 0],\n                            [1, 1, 1, 0, 0, 0, 0],\n                            [1, 1, 1, 1, 1, 0, 0]])\n    #eos_token=[1,2]\n    response_mask:  tensor([[1, 1, 1, 1, 0, 0, 0],\n                            [1, 1, 1, 1, 0, 0, 0],\n                            [1, 1, 1, 0, 0, 0, 0],\n                            [1, 1, 1, 1, 1, 0, 0]])\n    \"\"\"\n    eos_mask = torch.isin(response_id, torch.tensor(eos_token, device=response_id.device)).int()\n    return (eos_mask.cumsum(dim=1) - eos_mask).eq(0).to(dtype)\n\n\ndef get_eos_mask(response_id: torch.Tensor, eos_token: int = 2, dtype=torch.int64):\n    \"\"\"\n    Get EOS mask for response sequences.\n    \n    e.g. end of sentence token=1\n    response_id: [0, 0, 2, 42, 3, 5, 1, 0, 0]\n    eos_mask:     [1, 1, 1, 1,  1, 1, 1, 0, 0]\n    \n    This is a simplified version of get_response_mask for single EOS token.\n    Used for VLA embodied rollout compatibility.\n    \n    Args:\n        response_id: Token IDs tensor\n        eos_token: End of sequence token ID (single int)\n        dtype: Output dtype\n        \n    Returns:\n        Boolean mask where 1 indicates valid tokens before EOS\n    \"\"\"\n    eos_mask = response_id.eq(eos_token).long()\n    eos_mask = (torch.cumsum(eos_mask, dim=1) - eos_mask).bool()\n    eos_mask = torch.logical_not(eos_mask).to(dtype)\n    return eos_mask\n\n\ndef compute_grad_norm(model: nn.Module):\n    total_grad_square = 0\n    for param in model.parameters():\n        if param.grad is not None:\n            total_grad_square += torch.sum(torch.square(param.grad.detach())).item()\n    return total_grad_square\n\n\ndef broadcast_dict_tensor(tensors: Union[Dict[str, torch.Tensor], TensorDict], src, group):\n    \"\"\"\n    TODO: optimize this. 
Technically, we only need one broadcast\n    \"\"\"\n\n    for key in tensors.sorted_keys:\n        if isinstance(tensors[key], torch.Tensor):\n            torch.distributed.broadcast(tensors[key], src=src, group=group, async_op=False)\n\n\ndef allgather_dict_tensors(tensors: Union[Dict[str, torch.Tensor], TensorDict], size, group, dim=0):\n    \"\"\"\n    TODO: optimize this.\n    - We can use async ops\n    - We can use only one allgather\n    Args:\n        tensors:\n        size:\n        group:\n\n    Returns:\n\n    \"\"\"\n    if isinstance(tensors, TensorDict):\n        is_tensor_dict = True\n        tensors_as_dict = tensors.to_dict()\n    else:\n        tensors_as_dict = tensors\n        is_tensor_dict = False\n\n    output = {}\n    sorted_keys = sorted(tensors_as_dict.keys())\n    for key in sorted_keys:\n        val = tensors_as_dict[key]\n        output[key] = [torch.empty_like(val) for _ in range(size)]\n        torch.distributed.all_gather(output[key], val, group=group, async_op=False)\n        output[key] = torch.cat(output[key], dim=dim)\n\n    if is_tensor_dict:\n        output = TensorDict(source=output, batch_size=tensors.batch_size[0] * size)\n\n    return output\n\n\ndef split_dict_tensor_into_batches(tensors: TensorDict, batch_size) -> List[TensorDict]:\n    assert tensors.batch_size[0] % batch_size == 0, (\n        f\"input data batch size: {tensors.batch_size[0]}, split batch size: {batch_size}\"\n    )\n    return tensors.split(batch_size)\n\n\ndef pad_2d_list_to_length(response, pad_token_id, max_length=None):\n    \"\"\"\n    pad a 2D list (e.g. responses, logprobs) to a 2D tensor.\n    \"\"\"\n    response_length = max(len(sub_list) for sub_list in response)\n    target_length = max_length if max_length is not None and max_length > response_length else response_length\n    padded_response = [tuple(sub_list) + (pad_token_id,) * (target_length - len(sub_list)) for sub_list in response]\n    tensor = torch.tensor(padded_response)\n    return tensor\n\n\ndef pad_sequence_to_length(tensors, max_seq_len, pad_token_id, left_pad=False):\n    \"\"\"\n    pad a 2D tensors (e.g. 
responses, logprobs) in the last dim to max_seq_length.\n    input shape: [bs, seq_length]\n    output shape: [bs, max_seq_length]\n    \"\"\"\n    if tensors.shape[-1] >= max_seq_len:\n        return tensors\n    # (0, max_seq_len - tensors.shape[-1]) means right pad to max_seq_length and no left pad\n    pad_tuple = (max_seq_len - tensors.shape[-1], 0) if left_pad else (0, max_seq_len - tensors.shape[-1])\n    return F.pad(tensors, pad_tuple, \"constant\", pad_token_id)\n\n\ndef postprocess_data(\n    input_ids: torch.Tensor,\n    attention_mask: torch.Tensor,\n    max_length: int,\n    pad_token_id: int,\n    left_pad=True,\n    truncation=\"error\",\n):\n    \"\"\"Process tokenizer outputs to consistent shapes via padding/truncation.\n\n    Args:\n        input_ids: Token indices [batch_size, seq_len]\n        attention_mask: Mask [batch_size, seq_len]\n        max_length: Target sequence length\n        pad_token_id: Padding token ID\n        left_pad: Pad left if True\n        truncation: \"left\", \"right\" or \"error\"\n\n    Returns:\n        (input_ids, attention_mask) padded/truncated to max_length\n    \"\"\"\n    assert truncation in [\"left\", \"right\", \"middle\", \"error\"]\n    assert input_ids.ndim == 2\n\n    sequence_length = input_ids.shape[-1]\n    if sequence_length < max_length:\n        input_ids = pad_sequence_to_length(\n            input_ids, max_seq_len=max_length, pad_token_id=pad_token_id, left_pad=left_pad\n        )\n        attention_mask = pad_sequence_to_length(\n            attention_mask, max_seq_len=max_length, pad_token_id=0, left_pad=left_pad\n        )\n    elif sequence_length > max_length:\n        if truncation == \"left\":\n            # actually, left truncation may not be reasonable\n            input_ids = input_ids[:, -max_length:]\n            attention_mask = attention_mask[:, -max_length:]\n        elif truncation == \"right\":\n            input_ids = input_ids[:, :max_length]\n            attention_mask = attention_mask[:, :max_length]\n        elif truncation == \"middle\":\n            left_half = max_length // 2\n            right_half = max_length - left_half\n            input_ids = torch.cat([input_ids[:, :left_half], input_ids[:, -right_half:]], dim=-1)\n            attention_mask = torch.cat([attention_mask[:, :left_half], attention_mask[:, -right_half:]], dim=-1)\n        elif truncation == \"error\":\n            raise NotImplementedError(f\"{sequence_length=} is larger than {max_length=}\")\n        else:\n            raise NotImplementedError(f\"Unknown truncation method {truncation}\")\n\n    return input_ids, attention_mask\n\n\ndef tokenize_and_postprocess_data(\n    prompt: str, tokenizer: PreTrainedTokenizer, max_length: int, pad_token_id: int, left_pad=True, truncation=\"error\"\n):\n    \"\"\"Tokenize text and process outputs to consistent tensor shapes.\n\n    Args:\n        prompt: Input text to tokenize\n        tokenizer: HuggingFace tokenizer instance\n        max_length: Target sequence length\n        pad_token_id: Padding token ID\n        left_pad: Pad left if True\n        truncation: Truncation strategy (\"left\"/\"right\"/\"error\")\n\n    Returns:\n        Tuple of (input_ids, attention_mask) from postprocess_data\n    \"\"\"\n    input_data = tokenizer(prompt, return_tensors=\"pt\", add_special_tokens=False)\n    input_ids = input_data[\"input_ids\"]\n    attention_mask = input_data[\"attention_mask\"]\n\n    return postprocess_data(input_ids, attention_mask, max_length, pad_token_id, left_pad, 
truncation)\n\n\ndef remove_pad_token(input_ids: torch.Tensor, attention_mask: torch.Tensor):\n    \"\"\"Remove the pad token.\n\n    Args:\n        input_ids shape: [bs, seq_length]\n        attention_mask shape: [bs, seq_length]\n    Returns:\n        no_padding_batch(List[List[int]]): contains the rmpad token ids per query.\n    \"\"\"\n    no_padding_batch = []\n    for ids, mask in zip(input_ids, attention_mask):\n        no_padding_batch.append((ids[len(ids) - mask.sum() :]).cpu().numpy().tolist())\n    return no_padding_batch\n\n\ndef log_probs_from_logits_response(input_ids, logits, response_length):\n    \"\"\"Compute the response log_probs from full logits. Note that logits = model(input_ids)\n\n    Args:\n        input_ids: [batch_size, seqlen]\n        logits: [batch_size, seqlen, vocab_size]\n\n    Returns:\n        response_log_prob:\n    \"\"\"\n    response_logits = logits[:, -response_length - 1 : -1]\n    response = input_ids[:, -response_length:]\n    response_log_prob = logprobs_from_logits(logits=response_logits, labels=response)\n    return response_log_prob\n\n\ndef log_probs_from_logits_response_rmpad(input_ids, attention_mask, logits_rmpad, response_length):\n    \"\"\"Compute the log_probs from logits with rmpad logits and pad input. Note that\n    logits_rmpad = model(input_ids_rmpad). For each sentences, there is a shift between\n    logits and input_ids.\n    The reason for this function to is to compute logprobs_from_logits in rmpad mode because it is memory-intensive\n    for large vocab_size\n\n    Args:\n        input_ids: [batch_size, seqlen]\n        attention_mask: [batch_size, seqlen]\n        logits_rmpad: [total_nnz, vocab_size]\n        response_length: int\n    \"\"\"\n    from flash_attn.bert_padding import pad_input, unpad_input\n\n    batch_size, seqlen = input_ids.shape\n    input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask=attention_mask)\n    input_ids_rmpad = input_ids_rmpad.squeeze(-1)\n    input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=0)\n    full_log_probs_rmpad = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled)  # (total_nnz,)\n    full_output = pad_input(\n        hidden_states=full_log_probs_rmpad.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen\n    )\n    output = full_output.squeeze(-1)[:, -response_length - 1 : -1]  # [batch_size, response_length]\n    return output\n\n\ndef log_probs_from_logits_all_rmpad(input_ids_rmpad, logits_rmpad, indices, batch_size, seqlen, response_length):\n    \"\"\"Compute the log_probs from logits with rmpad input_ids and logits. Note that\n    logits_rmpad = model(input_ids_rmpad). 
For each sentences, there is a shift between\n    logits and input_ids.\n    The reason for this function to is to compute logprobs_from_logits in rmpad mode because it is memory-intensive\n    for large vocab_size\n\n    Args:\n        input_ids_rmpad: [1, total_nnz]\n        logits_rmpad: [total_nnz, vocab_size]\n        indices: [total_nnz]\n        batch_size: int\n        seqlen: int\n        response_length: int\n    \"\"\"\n    from flash_attn.bert_padding import pad_input\n\n    input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # transpose back to [total_nnz, 1]\n    input_ids_rmpad = input_ids_rmpad.squeeze(-1)\n    input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=0)\n    full_log_probs_rmpad = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled)  # (total_nnz,)\n    full_output = pad_input(\n        hidden_states=full_log_probs_rmpad.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen\n    )\n    output = full_output.squeeze(-1)[:, -response_length - 1 : -1]  # [batch_size, response_length]\n    return output\n\n\ndef post_process_logits(input_ids, logits, temperature, top_k, top_p):\n    if temperature != 1.0:\n        logits = logits.div_(temperature)  # inplace operation to avoid OOM\n    # TODO: add them back\n    # if top_k is not None and top_k > 0:\n    #     logits = TopKLogitsWarper(top_k=top_k)(input_ids, logits)\n    # if top_p is not None and top_p < 1.0 and top_p > 0.0:\n    #     logits = TopPLogitsWarper(top_p=top_p)(input_ids, logits)\n    return logits\n\n\n\"\"\"\nOptimizer related\n\"\"\"\n\n\ndef get_cosine_schedule_with_warmup(\n    optimizer: Optimizer,\n    num_warmup_steps: int,\n    num_training_steps: int,\n    min_lr_ratio: float = 0.0,\n    num_cycles: float = 0.5,\n    last_epoch: int = -1,\n):\n    \"\"\"\n    Create a schedule with a learning rate that decreases following the values of the cosine function between the\n    initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the\n    initial lr set in the optimizer.\n    Args:\n        optimizer (:class:`~torch.optim.Optimizer`):\n            The optimizer for which to schedule the learning rate.\n        num_warmup_steps (:obj:`int`):\n            The number of steps for the warmup phase.\n        num_training_steps (:obj:`int`):\n            The total number of training steps.\n        min_lr_ratio (:obj:`float`, `optional`, defaults to 0.0):\n            The minimum lr ratio w.r.t the maximum.\n        num_cycles (:obj:`float`, `optional`, defaults to 0.5):\n            The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0\n            following a half-cosine).\n        last_epoch (:obj:`int`, `optional`, defaults to -1):\n            The index of the last epoch when resuming training.\n    Return:\n        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.\n    \"\"\"\n    min_lr_ratio = 0.0 if min_lr_ratio is None else min_lr_ratio\n    assert min_lr_ratio >= 0 and min_lr_ratio <= 1.0\n    coef = (1 - min_lr_ratio) * 0.5\n    intercept = (1 + min_lr_ratio) * 0.5\n\n    def lr_lambda(current_step):\n        if current_step < num_warmup_steps:\n            return min_lr_ratio + (1.0 - min_lr_ratio) * (float(current_step) / float(max(1, num_warmup_steps)))\n        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))\n        x = math.cos(math.pi * float(num_cycles) * 2.0 * 
progress)\n        return max(min_lr_ratio, x * coef + intercept)\n\n    return LambdaLR(optimizer, lr_lambda, last_epoch)\n\n\ndef get_constant_schedule_with_warmup(\n    optimizer: Optimizer,\n    num_warmup_steps: int,\n    last_epoch: int = -1,\n):\n    \"\"\"\n    Create a constant LR schedule with a linear warmup phase.\n\n    Args:\n        optimizer (Optimizer): Wrapped optimizer.\n        num_warmup_steps (int): Number of steps to ramp up the LR from 0 to initial value.\n        last_epoch (int, optional): The index of the last epoch when resuming training. Defaults to -1.\n\n    Returns:\n        LambdaLR: Scheduler that increases LR linearly during warmup, then holds it constant.\n    \"\"\"\n\n    def lr_lambda(current_step):\n        if current_step < num_warmup_steps:\n            return float(current_step) / float(max(1.0, num_warmup_steps))\n        return 1.0\n\n    return LambdaLR(optimizer, lr_lambda, last_epoch)\n\n\ndef prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds):\n    # create causal mask\n    # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n    combined_attention_mask = None\n    if input_shape[-1] > 1:\n        combined_attention_mask = _make_causal_mask(\n            input_shape,\n            inputs_embeds.dtype,\n            device=inputs_embeds.device,\n        )\n\n    if attention_mask is not None:\n        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n        expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(\n            inputs_embeds.device\n        )\n        combined_attention_mask = (\n            expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask\n        )\n\n    return combined_attention_mask\n\n\n# Copied from transformers.models.bart.modeling_bart._make_causal_mask\ndef _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device):\n    \"\"\"\n    Make causal mask used for bi-directional self-attention.\n    \"\"\"\n    bsz, tgt_len = input_ids_shape\n    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)\n    mask_cond = torch.arange(mask.size(-1), device=device)\n    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)\n    mask = mask.to(dtype)\n    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len)\n\n\n# Copied from transformers.models.bart.modeling_bart._expand_mask\ndef _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):\n    \"\"\"\n    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.\n    \"\"\"\n    bsz, src_len = mask.size()\n    tgt_len = tgt_len if tgt_len is not None else src_len\n\n    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)\n\n    inverted_mask = 1.0 - expanded_mask\n\n    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)\n\n\ndef get_unpad_data(attention_mask):\n    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)\n    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()\n    max_seqlen_in_batch = seqlens_in_batch.max().item()\n    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))\n    return (\n        indices,\n        cu_seqlens,\n        max_seqlen_in_batch,\n    )\n\n\ndef get_wsd_schedule_with_warmup(\n    optimizer: Optimizer,\n    num_warmup_steps: int,\n    
num_training_steps: int,\n    min_lr_ratio: float = 0.0,\n    num_cycles: float = 0.5,\n    last_epoch: int = -1,\n    stable_ratio: float = 0.9,\n):\n    \"\"\"\n    Create a Warmup-Stable-Decay learning rate scheduler.\n\n    The schedule follows three phases:\n    1. Warmup: Learning rate increases linearly from 0 to the initial LR\n    2. Stable: Learning rate remains constant at the initial LR\n    3. Decay: Learning rate decreases following a cosine curve to min_lr_ratio * initial LR\n\n    Args:\n        optimizer (:class:`~torch.optim.Optimizer`):\n            The optimizer for which to schedule the learning rate.\n        num_warmup_steps (:obj:`int`):\n            The number of steps for the warmup phase.\n        num_training_steps (:obj:`int`):\n            The total number of training steps.\n        min_lr_ratio (:obj:`float`, `optional`, defaults to 0.0):\n            The minimum learning rate ratio w.r.t the initial learning rate.\n        num_cycles (:obj:`float`, `optional`, defaults to 0.5):\n            The number of waves in the cosine schedule during decay phase.\n        last_epoch (:obj:`int`, `optional`, defaults to -1):\n            The index of the last epoch when resuming training.\n        stable_ratio (:obj:`float`, `optional`, defaults to 0.0):\n            The ratio of non-warmup steps that should maintain a constant learning rate.\n            Set to 0.0 to behave exactly like cosine schedule.\n\n    Return:\n        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.\n    \"\"\"\n    remaining_steps = max(0, num_training_steps - num_warmup_steps)\n    num_stable_steps = int(remaining_steps * stable_ratio)\n    num_decay_steps = remaining_steps - num_stable_steps\n\n    def lr_lambda(current_step):\n        if current_step < num_warmup_steps:\n            return float(current_step) / float(max(1, num_warmup_steps))\n        if current_step < num_warmup_steps + num_stable_steps:\n            return 1.0\n        if current_step < num_training_steps:\n            progress = float(current_step - num_warmup_steps - num_stable_steps) / float(max(1, num_decay_steps))\n            value = max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))\n            return (1.0 - min_lr_ratio) * value + min_lr_ratio\n        return min_lr_ratio\n\n    return LambdaLR(optimizer, lr_lambda, last_epoch)\n\n\n@contextmanager\ndef check_device_is_available():\n    \"\"\"\n    Some modules must be imported after CUDA is initialized. Such as sglang's sharding manager.\n\n    This context manager checks if CUDA is available and raises an error if it is not.\n    \"\"\"\n    if not get_torch_device().is_available():\n        raise RuntimeError(\"Device {} must be initialized before importing this module.\".format(get_device_name()))\n\n    yield\n\n\ndef distributed_mean_max_min_std(local_tensor, compute_max=True, compute_min=True, compute_std=True):\n    \"\"\"Compute distributed statistics across all processes.\n\n    Args:\n        local_tensor: Tensor containing local values\n        compute_max: Include maximum value calculation\n        compute_min: Include minimum value calculation\n        compute_std: Include standard deviation calculation\n\n    Returns:\n        Tuple containing (mean, max, min, std) in this order. 
None for disabled metrics.\n    \"\"\"\n    # Sum the local tensor across all processes\n    local_sum = torch.sum(local_tensor)\n    local_num = torch.tensor(torch.numel(local_tensor), device=get_device_name())\n\n    torch.distributed.all_reduce(local_sum, op=torch.distributed.ReduceOp.SUM)\n    torch.distributed.all_reduce(local_num, op=torch.distributed.ReduceOp.SUM)\n\n    global_mean = local_sum / local_num\n\n    if compute_max:\n        local_max = torch.max(local_tensor)\n        torch.distributed.all_reduce(local_max, op=torch.distributed.ReduceOp.MAX)\n    else:\n        local_max = None\n\n    if compute_min:\n        local_min = torch.min(local_tensor)\n        torch.distributed.all_reduce(local_min, op=torch.distributed.ReduceOp.MIN)\n    else:\n        local_min = None\n\n    if compute_std:\n        square_diff = torch.sum(torch.pow(local_tensor - global_mean, 2))\n        torch.distributed.all_reduce(square_diff, op=torch.distributed.ReduceOp.SUM)\n        global_std = torch.sqrt(square_diff / (local_num - 1))\n    else:\n        global_std = None\n\n    return global_mean, local_max, local_min, global_std\n\n\ndef distributed_masked_mean(local_tensor, local_mask):\n    \"\"\"Compute global mean of non-masked elements across distributed processes.\n\n    Args:\n        local_tensor (torch.Tensor): Input tensor with local values\n        local_mask (torch.Tensor): Binary mask (1=valid, 0=ignore) matching local_tensor shape\n\n    Returns:\n        torch.Tensor: Global mean of all valid elements across processes\n    \"\"\"\n    local_tensor = local_tensor * local_mask\n\n    local_sum = torch.sum(local_tensor)\n    local_num = torch.sum(local_mask)\n\n    torch.distributed.all_reduce(local_sum, op=torch.distributed.ReduceOp.SUM)\n    torch.distributed.all_reduce(local_num, op=torch.distributed.ReduceOp.SUM)\n\n    global_mean = local_sum / local_num\n    return global_mean\n"
  },
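A short sketch tying a few of the helpers above together (assumes `siirl` is importable; the shapes, vocab size, and `eos_token=1` are made-up demo values): compute per-token log-probs for sampled labels, mask everything after the first EOS, and take a masked mean.

```python
import torch
from siirl.utils.model_utils import torch_functional as F_utils

logits = torch.randn(2, 5, 11)                 # [batch, seq, vocab]
labels = torch.randint(0, 11, (2, 5))          # sampled token ids

# reference (non-flash-attn) per-token log-probs, shape [batch, seq]
logp = F_utils.logprobs_from_logits_naive(logits, labels)

# keep tokens up to and including the first EOS (eos_token=1 chosen for the demo)
response_mask = F_utils.get_response_mask(labels, eos_token=1)

mean_logp = F_utils.masked_mean(logp, response_mask)
print(logp.shape, mean_logp.item())
```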
  {
    "path": "siirl/utils/model_utils/ulysses.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nUtilities for DeepSpeed Ulysses Sequence Parallelism.\nDeepSpeed Ulysses Paper: https://arxiv.org/abs/2309.14509\nInspired from: https://github.com/deepspeedai/DeepSpeed/blob/master/deepspeed/sequence/layer.py\n\"\"\"\n\nfrom typing import Any, Optional, Tuple\n\nimport torch\nimport torch.distributed as dist\nfrom torch import Tensor\nfrom torch.distributed import ProcessGroup\n\n_ULYSSES_SEQUENCE_PARALLEL_GROUP = None\n\n\ndef set_ulysses_sequence_parallel_group(group: dist.ProcessGroup):\n    \"\"\"\n    Set ulysses sequence parallel process group.\n    \"\"\"\n    global _ULYSSES_SEQUENCE_PARALLEL_GROUP\n    _ULYSSES_SEQUENCE_PARALLEL_GROUP = group\n\n\ndef get_ulysses_sequence_parallel_group() -> Optional[dist.ProcessGroup]:\n    \"\"\"\n    Get ulysses sequence parallel process group.\n    \"\"\"\n    global _ULYSSES_SEQUENCE_PARALLEL_GROUP\n    return _ULYSSES_SEQUENCE_PARALLEL_GROUP\n\n\ndef get_ulysses_sequence_parallel_world_size(group: ProcessGroup = None) -> int:\n    \"\"\"\n    Get ulysses sequence parallel world size.\n    \"\"\"\n    group = get_ulysses_sequence_parallel_group() if group is None else group\n    return dist.get_world_size(group) if group else 1\n\n\ndef get_ulysses_sequence_parallel_rank(group: ProcessGroup = None) -> int:\n    \"\"\"\n    Get ulysses sequence parallel rank.\n    \"\"\"\n    group = get_ulysses_sequence_parallel_group() if group is None else group\n    return dist.get_rank(group) if group else 0\n\n\ndef gather_seq_scatter_heads(\n    x: Tensor,\n    seq_dim: int,\n    head_dim: int,\n    unpadded_dim_size: int = 0,\n    group: ProcessGroup = None,\n) -> Tensor:\n    \"\"\"\n    A func to sync embedding input with alltoall in sequence parallel\n    gather sequence dimension and scatter head dim:\n    e.g. seq_dim: 1, head_dim: 2\n    [bsz, seq/n, h, ...] -> [bsz, seq, h/n, ...]\n    \"\"\"\n    group = get_ulysses_sequence_parallel_group() if group is None else group\n    if not group:\n        return x\n    sp_world = get_ulysses_sequence_parallel_world_size(group)\n    x = SeqAllToAll.apply(group, x, head_dim, seq_dim)\n    if unpadded_dim_size and unpadded_dim_size % sp_world != 0:\n        padding_size = x.size(seq_dim) - unpadded_dim_size\n        x = _unpad_tensor(x, seq_dim, padding_size)\n    return x\n\n\ndef gather_heads_scatter_seq(x: Tensor, head_dim: int, seq_dim: int, group: ProcessGroup = None) -> Tensor:\n    \"\"\"\n    A func to sync attention result with alltoall in sequence parallel\n    gather head dimension and scatter seq dim:\n    e.g. seq_dim: 1, head_dim: 2\n    [bsz, seq, h/n, ...] 
-> [bsz, seq/n, h, ...]\n    \"\"\"\n    group = get_ulysses_sequence_parallel_group() if group is None else group\n    if not group:\n        return x\n    dim_size = x.size(seq_dim)\n    sp_world = get_ulysses_sequence_parallel_world_size(group)\n    if dim_size % sp_world != 0:\n        padding_size = sp_world - (dim_size % sp_world)\n        x = _pad_tensor(x, seq_dim, padding_size)\n    return SeqAllToAll.apply(group, x, seq_dim, head_dim, False)\n\n\ndef _pad_tensor(x: Tensor, dim: int, padding_size: int) -> Tensor:\n    shape = list(x.shape)\n    shape[dim] = padding_size\n    pad = torch.zeros(shape, dtype=x.dtype, device=x.device)\n    return torch.cat([x, pad], dim=dim)\n\n\ndef _unpad_tensor(x: Tensor, dim: int, padding_size: int) -> Tensor:\n    slc = [slice(None)] * len(x.shape)\n    slc[dim] = slice(0, -padding_size)\n    return x[slc]\n\n\ndef slice_input_tensor(x: Tensor, dim: int, padding: bool = True, group: ProcessGroup = None) -> Tensor:\n    group = get_ulysses_sequence_parallel_group() if group is None else group\n    sp_world_size = dist.get_world_size(group)\n    sp_rank = get_ulysses_sequence_parallel_rank()\n    dim_size = x.size(dim)\n    # pad before slice\n    if padding and dim_size % sp_world_size:\n        padding_size = sp_world_size - (dim_size % sp_world_size)\n        x = _pad_tensor(x, dim, padding_size)\n    # slice the input tensor\n    parts = x.size(dim) // sp_world_size\n    slc = [slice(None)] * len(x.shape)\n    slc[dim] = slice(sp_rank * parts, (sp_rank + 1) * parts)\n    return x[slc].contiguous()\n\n\ndef all_to_all_tensor(\n    local_input: Tensor,\n    scatter_dim: int,\n    gather_dim: int,\n    group: Optional[dist.ProcessGroup] = None,\n    async_op: bool = False,\n):\n    group = get_ulysses_sequence_parallel_group() if group is None else group\n    seq_world_size = dist.get_world_size(group)\n    input_list = [t.contiguous() for t in torch.tensor_split(local_input, seq_world_size, scatter_dim)]\n    output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)]\n    comm = dist.all_to_all(output_list, input_list, group=group, async_op=async_op)\n    if async_op:\n\n        def wait():\n            comm.wait()\n            return torch.cat(output_list, dim=gather_dim).contiguous()\n\n        return wait\n    return torch.cat(output_list, dim=gather_dim).contiguous()\n\n\ndef all_gather_tensor(local_tensor: Tensor, group: Optional[dist.ProcessGroup] = None, async_op: bool = False):\n    group = get_ulysses_sequence_parallel_group() if group is None else group\n    sp_world_size = dist.get_world_size(group=group)\n    output_shape = list(local_tensor.shape)\n    output_shape[0] = output_shape[0] * sp_world_size\n    output = torch.empty(output_shape, dtype=local_tensor.dtype, device=local_tensor.device)\n    dist.all_gather_into_tensor(output, local_tensor, group=group, async_op=async_op)\n    return output\n\n\nclass SeqAllToAll(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx: Any,\n        group: dist.ProcessGroup,\n        local_input: Tensor,\n        scatter_dim: int,\n        gather_dim: int,\n        async_op: bool = False,\n    ) -> Tensor:\n        ctx.group = group\n        ctx.scatter_dim = scatter_dim\n        ctx.gather_dim = gather_dim\n        ctx.async_op = async_op\n        return all_to_all_tensor(local_input, scatter_dim, gather_dim, group, async_op)\n\n    @staticmethod\n    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:\n        input_t = 
torch.cat(grad_output[1:], dim=ctx.gather_dim).contiguous() if ctx.async_op else grad_output[0]\n        return (\n            None,\n            all_to_all_tensor(input_t, ctx.gather_dim, ctx.scatter_dim, ctx.group, False),\n            None,\n            None,\n            None,\n            None,\n        )\n\n\nclass Gather(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx: Any,\n        group: dist.ProcessGroup,\n        local_tensor: Tensor,\n        gather_dim: int,\n        grad_scaler: bool = True,\n        async_op=False,\n    ) -> Tensor:\n        ctx.group = group\n        ctx.gather_dim = gather_dim\n        ctx.grad_scaler = grad_scaler\n        ctx.async_op = async_op\n\n        sp_world_size = dist.get_world_size(group=group)\n        ctx.sp_world_size = sp_world_size\n\n        sp_rank = dist.get_rank(group=group)\n        ctx.sp_rank = sp_rank\n\n        local_shape = list(local_tensor.size())\n        split_size = local_shape[0]\n        part_size = local_shape[gather_dim]  # store original size\n        ctx.part_size = part_size\n\n        output = all_gather_tensor(local_tensor, group, async_op)\n        return torch.cat(output.split(split_size, dim=0), dim=gather_dim)\n\n    @staticmethod\n    def backward(ctx: Any, grad_output: Tensor) -> Any:\n        if ctx.grad_scaler:\n            grad_output = grad_output * ctx.sp_world_size\n        return (\n            None,\n            grad_output.split(ctx.part_size, dim=ctx.gather_dim)[ctx.sp_rank].contiguous(),\n            None,\n            None,\n            None,\n            None,\n        )\n\n\ndef gather_outpus_and_unpad(\n    x: Tensor,\n    gather_dim: int,\n    unpad_dim: int = None,\n    padding_size: int = 0,\n    grad_scaler: bool = True,\n    group: Optional[dist.ProcessGroup] = None,\n):\n    \"\"\"\n    Gather a tensor across a process group and optionally unpad its padded elements.\n\n    Args:\n        x (Tensor): Input tensor to gather.\n        gather_dim (int): Dimension along which to gather across ranks.\n        unpad_dim (int, optional): Dimension from which to remove padding. If None, no unpadding.\n        padding_size (int): Number of padding elements to remove on `unpad_dim`. Defaults to 0.\n        grad_scaler (bool): Whether to apply gradient scaling during gather. Defaults to True.\n        group (ProcessGroup, optional): Process group for gathering. If None, uses\n            `get_ulysses_sequence_parallel_group()`. 
If still None, returns `x` unchanged.\n\n    Returns:\n        Tensor: The gathered tensor, with padding removed if requested.\n    \"\"\"\n    group = get_ulysses_sequence_parallel_group() if group is None else group\n    if group is None:\n        return x\n    x = Gather.apply(group, x, gather_dim, grad_scaler)\n    if unpad_dim is not None:\n        assert isinstance(padding_size, int), \"padding size is not given or is not an integer\"\n        if padding_size == 0:\n            return x\n        x = _unpad_tensor(x, unpad_dim, padding_size)\n    return x\n\n\ndef ulysses_pad(input_ids_rmpad: Tensor, position_ids_rmpad: Optional[Tensor] = None, sp_size: int = 1) -> tuple[Tensor, Optional[Tensor], int]:\n    \"\"\"\n    A generalized function to pad input_ids and optional position_ids for Ulysses sequence parallelism.\n    compatible with both standard Language Models and Vision-Language Models (VLMs) that might\n    have special position_id shapes.\n\n    Args:\n        input_ids_rmpad (Tensor): The unpadded input IDs tensor.\n        position_ids_rmpad (Optional[Tensor]): The optional unpadded position IDs tensor.\n        sp_size (int): The world size of the sequence parallelism group.\n\n    Returns:\n        A tuple containing:\n        - The padded input_ids tensor.\n        - The padded position_ids tensor (or None if not provided).\n        - The padding size applied.\n    \"\"\"\n    # If sequence parallelism is not enabled (size is 1 or less), do nothing and return immediately.\n    if sp_size <= 1:\n        return input_ids_rmpad, position_ids_rmpad, 0\n\n    total_seq_len = input_ids_rmpad.size(-1)\n\n    # If position_ids are provided, validate that their sequence dimension (the last one) matches the input_ids.\n    if position_ids_rmpad is not None:\n        assert input_ids_rmpad.size(-1) == position_ids_rmpad.size(-1), f\"Sequence dimension mismatch between input_ids and position_ids: {input_ids_rmpad.shape} vs {position_ids_rmpad.shape}\"\n\n    # Calculate the required padding size to make the sequence length divisible by sp_size.\n    pad_size = (sp_size - total_seq_len % sp_size) % sp_size\n\n    if pad_size > 0:\n        # pads the last dimension with `pad_size` elements, with a value of 0.\n        input_ids_rmpad = torch.nn.functional.pad(input_ids_rmpad, (0, pad_size), value=0)\n\n        if position_ids_rmpad is not None:\n            position_ids_rmpad = torch.nn.functional.pad(position_ids_rmpad, (0, pad_size), value=0)\n\n    return input_ids_rmpad, position_ids_rmpad, pad_size\n\n\ndef ulysses_pad_and_slice_inputs(input_ids_rmpad: torch.Tensor, position_ids_rmpad: Optional[torch.Tensor] = None, sp_size: int = 1):\n    \"\"\"\n    Pad and slice input_ids to be divisible by sp_size\n    Pad position_ids to be divisible by sp_size.\n\n    Note both input_ids_rmpad and position_ids_rmpad will be padded and sliced.\n\n    The is the utility of pre-forward for ulysses sequence parallelism\n\n    Args:\n        input_ids_rmpad: shape of [bsz, seqlen]\n        position_ids_rmpad: shape of [bsz, seqlen], where bsz must be 1\n        sp_size (int): ulysses sequence parallelism size\n\n    Returns:\n        torch.Tensor: padded and sliced input_ids\n        torch.Tensor: padded and sliced position_ids\n        int: pad size\n    \"\"\"\n    input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad(input_ids_rmpad, position_ids_rmpad, sp_size)\n    input_ids_rmpad = slice_input_tensor(input_ids_rmpad, dim=1, padding=False)\n    if position_ids_rmpad is not None:\n    
    position_ids_rmpad = slice_input_tensor(position_ids_rmpad, dim=1, padding=False)\n    return input_ids_rmpad, position_ids_rmpad, pad_size\n\n\ndef validate_ulysses_config(num_heads, ulysses_sequence_size):\n    if ulysses_sequence_size > 1:\n        assert num_heads % ulysses_sequence_size == 0, f\"num_heads ({num_heads}) must be divisible by ulysses sequence size({ulysses_sequence_size})\"\n"
  },
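A minimal sketch of the padding arithmetic used before slicing for Ulysses sequence parallelism (assumes `siirl` is importable; `sp_size=4` and the 10-token packed sequence are made-up values, and no process group is needed for this step):

```python
import torch
from siirl.utils.model_utils.ulysses import ulysses_pad

input_ids = torch.arange(10).unsqueeze(0)      # [1, total_nnz] packed (rmpad) tokens
position_ids = torch.arange(10).unsqueeze(0)

padded_ids, padded_pos, pad_size = ulysses_pad(input_ids, position_ids, sp_size=4)
assert pad_size == 2 and padded_ids.shape[-1] == 12   # 10 -> 12, divisible by sp_size
# each of the 4 SP ranks would then take a contiguous 3-token slice via
# slice_input_tensor, which does require the Ulysses process group to be set
```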
  {
    "path": "siirl/utils/model_utils/vllm_utils.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# To support different vLLM versions, we add the model into SUPPORTED_MOE_MODELS separately to avoid triggering unsupported issues.\nSUPPORTED_MOE_MODELS = []\n\ntry:\n    from vllm.model_executor.models.deepseek_v2 import DeepseekV2ForCausalLM, DeepseekV3ForCausalLM\n\n    SUPPORTED_MOE_MODELS.append(DeepseekV2ForCausalLM)\n    SUPPORTED_MOE_MODELS.append(DeepseekV3ForCausalLM)\nexcept ImportError:\n    pass\n\ntry:\n    from vllm.model_executor.models.mixtral import MixtralForCausalLM\n\n    SUPPORTED_MOE_MODELS.append(MixtralForCausalLM)\nexcept ImportError:\n    pass\n\ntry:\n    from vllm.model_executor.models.qwen2_moe import Qwen2MoeForCausalLM\n\n    SUPPORTED_MOE_MODELS.append(Qwen2MoeForCausalLM)\nexcept ImportError:\n    pass\n\ntry:\n    from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM\n\n    SUPPORTED_MOE_MODELS.append(Qwen3MoeForCausalLM)\nexcept ImportError:\n    pass\n\ntry:\n    from vllm.model_executor.models.kimi_vl import KimiVLForConditionalGeneration\n\n    SUPPORTED_MOE_MODELS.append(KimiVLForConditionalGeneration)\nexcept ImportError:\n    pass\n\nfrom typing import List\n\nfrom msgspec import field\nfrom packaging import version as vs\nfrom vllm.lora.models import LoRAModel\nfrom vllm.lora.request import LoRARequest\nfrom vllm.lora.utils import get_adapter_absolute_path\nfrom vllm.lora.worker_manager import LRUCacheWorkerLoRAManager\n\n\ndef patch_vllm_moe_model_weight_loader(model):\n    # this is a work around to load the weight of vllm fused moe model\n    # it is from a bug from vllm 0.8.2\n    # all the weights are supposed to have a weight_loader, but the moe weights\n    # do not have a weight_loader, so we need to patch it\n    # (True, 'model.embed_tokens.weight')\n    # (True, 'model.layers.0.self_attn.qkv_proj.weight')\n    # (True, 'model.layers.0.self_attn.qkv_proj.bias')\n    # (True, 'model.layers.0.self_attn.o_proj.weight')\n    # (True, 'model.layers.0.mlp.gate.weight')\n    # (True, 'model.layers.0.mlp.shared_expert.gate_up_proj.weight')\n    # (True, 'model.layers.0.mlp.shared_expert.down_proj.weight')\n    # (False, 'model.layers.0.mlp.shared_expert_gate.weight')   use default\n    # (False, 'model.layers.0.input_layernorm.weight')          use default\n    # (False, 'model.layers.0.post_attention_layernorm.weight') use default\n    # (False, 'model.layers.0.mlp.experts.w13_weight')          use mlp.experts.weight_loader\n    # (False, 'model.layers.0.mlp.experts.w2_weight')          use mlp.experts.weight_loader\n\n    # Define MLP attribute mapping for different model types\n    MLP_ATTR_MAPPING = {\n        MixtralForCausalLM: \"block_sparse_moe\",\n    }\n    DEFAULT_MLP_ATTR = \"mlp\"\n\n    if not isinstance(model, tuple(SUPPORTED_MOE_MODELS)):\n        return\n\n    model = getattr(model, \"model\", None) or getattr(model, \"language_model\", None)\n    if model is None:\n        raise 
ValueError(\"The provided model does not have a valid 'model' or 'language_model' attribute.\")\n\n    for layer in model.layers:\n        mlp_attr = MLP_ATTR_MAPPING.get(type(model), DEFAULT_MLP_ATTR)\n        mlp = getattr(layer, mlp_attr)\n\n        param_dict = dict(mlp.named_parameters())\n        for name, param in param_dict.items():\n            if \"w13_weight\" in name or \"w2_weight\" in name:\n                param.weight_loader = mlp.experts.weight_loader\n\n\nclass TensorLoRARequest(LoRARequest):\n    peft_config: dict = field(default=None)\n    lora_tensors: dict = field(default=None)\n\n\nclass VLLMHijack:\n    @staticmethod\n    def hijack():\n        def hijack__load_adapter(self, lora_request: TensorLoRARequest) -> LoRAModel:\n            \"\"\"\n            based on vllm.lora.worker_manager.WorkerLoRAManager._load_adapter, support load adapter with lora tensors\n\n            Reason:\n            VLLM does not support adding LoRA from tensors directly. It only supports adding LoRA via file paths.\n            To synchronize the LoRA tensors of the actor model, we need to find a workaround to enable VLLM to load memory-based LoRA tensors.\n            \"\"\"\n            try:\n                supported_lora_modules = self._adapter_manager.supported_lora_modules\n                packed_modules_mapping = self._adapter_manager.packed_modules_mapping\n                expected_lora_modules: List[str] = []\n                for module in supported_lora_modules:\n                    if module in packed_modules_mapping:\n                        expected_lora_modules.extend(packed_modules_mapping[module])\n                    else:\n                        expected_lora_modules.append(module)\n\n                expected_lora_modules = list(set(expected_lora_modules))\n\n                lora_tensors = None\n                from vllm.lora.peft_helper import PEFTHelper\n\n                if isinstance(lora_request, TensorLoRARequest):\n                    peft_config = lora_request.peft_config\n                    lora_tensors = lora_request.lora_tensors\n                    peft_helper = PEFTHelper.from_dict(peft_config)\n                else:\n                    lora_path = get_adapter_absolute_path(lora_request.lora_path)\n\n                    peft_helper = PEFTHelper.from_local_dir(lora_path, self.max_position_embeddings)\n\n                # Validates the LoRA configuration against requirements before\n                # loading weights, throwing an exception if validation fails.\n                peft_helper.validate_legal(self.lora_config)\n\n                # For some models like Qwen2VL, we need to use hf_to_vllm_mapper\n                # to ensure correct loading of lora weights.\n                model = self._adapter_manager.model\n                hf_to_vllm_mapper = None\n                if hasattr(model, \"hf_to_vllm_mapper\") and model.hf_to_vllm_mapper is not None:\n                    hf_to_vllm_mapper = model.hf_to_vllm_mapper\n\n                if isinstance(lora_request, TensorLoRARequest):\n                    lora = self._lora_model_cls.from_lora_tensors(\n                        lora_model_id=lora_request.lora_int_id,\n                        tensors=lora_tensors,\n                        peft_helper=peft_helper,\n                        device=\"cpu\",\n                        dtype=self.lora_config.lora_dtype,\n                        embeddings=None,\n                        target_embedding_padding=self.vocab_size + self.lora_config.lora_extra_vocab_size,\n  
                      embedding_modules=self.embedding_modules,\n                        embedding_padding_modules=self.embedding_padding_modules,\n                        weights_mapper=hf_to_vllm_mapper,\n                    )\n                else:\n                    lora = self._lora_model_cls.from_local_checkpoint(\n                        lora_path,\n                        expected_lora_modules,\n                        peft_helper=peft_helper,\n                        lora_model_id=lora_request.lora_int_id,\n                        device=\"cpu\",\n                        dtype=self.lora_config.lora_dtype,\n                        target_embedding_padding=self.vocab_size + self.lora_config.lora_extra_vocab_size,\n                        embedding_modules=self.embedding_modules,\n                        embedding_padding_modules=self.embedding_padding_modules,\n                        weights_mapper=hf_to_vllm_mapper,\n                    )\n            except Exception as e:\n                raise e\n\n            if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:\n                raise ValueError(f\"LoRA added vocab size {lora.extra_vocab_size} is greater than lora_extra_vocab_size {self.lora_config.lora_extra_vocab_size}.\")\n            return lora\n\n        def do_hijack(target_cls, target_method_name, hooking_method):\n            setattr(target_cls, target_method_name, hooking_method)\n\n        do_hijack(LRUCacheWorkerLoRAManager, \"_load_adapter\", hijack__load_adapter)\n"
  },
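A minimal wiring sketch for the helpers above, assuming the caller already holds a PEFT adapter config and its tensors in memory; the names `peft_config_dict`, `lora_state_dict`, `actor_adapter`, and the dummy `lora_path` are illustrative, not part of the module. `patch_vllm_moe_model_weight_loader` is applied to the engine's underlying model during weight synchronization and is omitted here because reaching that model is engine-version specific.

```python
from siirl.utils.model_utils.vllm_utils import TensorLoRARequest, VLLMHijack

# Install the patched _load_adapter once, before any LoRA request reaches the engine.
VLLMHijack.hijack()

# Placeholders: in practice these come from the actor's PEFT adapter.
peft_config_dict: dict = {}   # e.g. LoraConfig(...).to_dict() from peft
lora_state_dict: dict = {}    # adapter weights kept in memory, name -> torch.Tensor

lora_request = TensorLoRARequest(
    lora_name="actor_adapter",
    lora_int_id=1,
    lora_path="in_memory",        # dummy path; the patched loader reads lora_tensors instead
    peft_config=peft_config_dict,
    lora_tensors=lora_state_dict,
)
# The request is then passed wherever a plain LoRARequest would be, e.g.
# llm.generate(prompts, sampling_params, lora_request=lora_request).
```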
  {
    "path": "siirl/utils/reward_score/__init__.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# from . import gsm8k, math, prime_math, prime_code\n\nfrom siirl.utils.extras.import_utils import deprecated\n\n\ndef default_compute_score(data_source, solution_str, ground_truth, extra_info=None, sandbox_fusion_url=None, concurrent_semaphore=None):\n    \"\"\"Compute the score for a given solution based on the data source.\n\n    Args:\n        data_source (str): The source dataset identifier which determines the scoring method.\n        solution_str (str): The solution string to be evaluated.\n        ground_truth (str): The ground truth answer for comparison.\n        extra_info (dict, optional): Additional information that might be needed for scoring. Defaults to None.\n\n    Returns:\n        float: The computed score as a floating point number. If the result is a dictionary,\n               it returns the dictionary instead.\n\n    Raises:\n        NotImplementedError: If the reward function is not implemented for the given data source.\n    \"\"\"\n    if data_source == \"openai/gsm8k\":\n        from . import gsm8k\n\n        res = gsm8k.compute_score(solution_str, ground_truth)\n    elif data_source in [\"lighteval/MATH\", \"DigitalLearningGmbH/MATH-lighteval\", \"agentica-org/DeepScaleR-Preview-Dataset\", \"AIME2024\", \"AIME2025\", \"AIME24\", \"AIME25\"]:\n        from . import math\n\n        res = math.compute_score(solution_str, ground_truth)\n        # [Optional] Math-Verify Integration\n        # For enhanced accuracy, consider utilizing Math-Verify (https://github.com/huggingface/Math-Verify).\n        # Note: Math-Verify needs to be manually installed via pip: `pip install math-verify`.\n        # To use it, override the `compute_score` function with the following implementation:\n\n        # from . import math_verify\n        # res = math_verify.compute_score(solution_str, ground_truth)\n    elif data_source == \"math_dapo\" or data_source.startswith(\"aime\"):\n        from . import math_dapo\n\n        res = math_dapo.compute_score(solution_str, ground_truth)\n    elif data_source in [\n        \"numina_aops_forum\",\n        \"numina_synthetic_math\",\n        \"numina_amc_aime\",\n        \"numina_synthetic_amc\",\n        \"numina_cn_k12\",\n        \"numina_olympiads\",\n    ]:\n        from . import prime_math\n\n        res = prime_math.compute_score(solution_str, ground_truth)\n    elif data_source in [\"codecontests\", \"apps\", \"codeforces\", \"taco\"]:\n        # Use the passed sandbox_fusion_url if available\n        if sandbox_fusion_url:\n            from . import sandbox_fusion\n\n            # Pass the URL directly, ground_truth likely contains test cases here\n            res = sandbox_fusion.compute_score(sandbox_fusion_url, concurrent_semaphore, solution_str, ground_truth, continuous=True)\n        else:\n            # If no sandbox URL is provided, fall back to prime_code or raise error\n            from . 
import prime_code\n\n            # Assuming prime_code doesn't need the URL\n            res = prime_code.compute_score(solution_str, ground_truth, continuous=True)\n    elif data_source in [\"hiyouga/geometry3k\"]:\n        from . import geo3k\n\n        res = geo3k.compute_score(solution_str, ground_truth)\n    elif data_source in [\"mm_eureka\"]:\n        from . import mm_eureka\n\n        res = mm_eureka.compute_score(solution_str, ground_truth)\n    elif data_source in [\"searchR1_nq\", \"searchR1_triviaqa\", \"searchR1_popqa\", \"searchR1_hotpotqa\", \"searchR1_2wikimultihopqa\", \"searchR1_musique\", \"searchR1_bamboogle\"]:\n        from . import search_r1_like_qa_em\n\n        res = search_r1_like_qa_em.compute_score(solution_str, ground_truth)\n    else:\n        raise NotImplementedError(f\"Reward function is not implemented for {data_source=}\")\n\n    if isinstance(res, dict):\n        return res\n    elif isinstance(res, (int, float, bool)):\n        return float(res)\n    else:\n        return float(res[0])\n\n\n@deprecated(\"siirl.utils.reward_score.default_compute_score\")\ndef _default_compute_score(data_source, solution_str, ground_truth, extra_info=None, sandbox_fusion_url=None, concurrent_semaphore=None):\n    \"\"\"\n    Legacy function API to be deprecated. Please use `default_compute_score` instead.\n    \"\"\"\n    return default_compute_score(data_source, solution_str, ground_truth, extra_info, sandbox_fusion_url, concurrent_semaphore)\n\n\n__all__ = [\"default_compute_score\"]\n"
  },
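For illustration, a small call to the dispatcher above: the `data_source` string selects the scorer, and unrecognized sources raise `NotImplementedError` rather than returning a default score.

```python
from siirl.utils.reward_score import default_compute_score

# GSM8K answers are compared against the number after the "#### " marker.
solution = "The bakery made 19 + 23 = 42 cakes in total.\n#### 42"
print(default_compute_score("openai/gsm8k", solution, "42"))  # 1.0

# default_compute_score("my/unknown_dataset", solution, "42")  # -> NotImplementedError
```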
  {
    "path": "siirl/utils/reward_score/embodied.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport re\nfrom typing import Any, Dict, List, Tuple\nfrom loguru import logger\nfrom tensordict import TensorDict\n# Handle different tensordict versions - NonTensorData location varies\ntry:\n    from tensordict import NonTensorData\nexcept ImportError:\n    from tensordict.tensorclass import NonTensorData\nimport numpy as np\nimport torch\nimport torch.distributed as dist\nfrom scipy import special\nfrom scipy.spatial.distance import cdist\nfrom sklearn.cluster import DBSCAN\nfrom sklearn.preprocessing import StandardScaler\n\n\ndef _tensor_to_str_list(tensor: torch.Tensor) -> List[str]:\n    \"\"\"Helper function to decode a byte tensor into a list of strings.\"\"\"\n    if tensor.ndim == 0:\n        tensor = tensor.unsqueeze(0)\n    byte_array = tensor.cpu().numpy()\n    return [bytes(x).decode(\"utf-8\", errors=\"ignore\").rstrip(\"\\0\") for x in byte_array]\n\n\ndef _extract_task_name(task_file_name: str) -> str:\n    \"\"\"Helper function to parse the base task name from a trial file name.\"\"\"\n    match = re.match(r\"(libero_\\w+_task_\\d+)_trial_\\d+\", task_file_name)\n    return match.group(1) if match else task_file_name\n\n\ndef _compute_cluster_centers(embeddings: np.ndarray, eps: float = 0.5, min_samples: int = 2) -> np.ndarray:\n    \"\"\"Compute cluster centers using DBSCAN clustering.\"\"\"\n    if len(embeddings) == 0:\n        return np.array([])\n    \n    scaler = StandardScaler()\n    scaled_embeddings = scaler.fit_transform(embeddings)\n    clustering = DBSCAN(eps=eps, min_samples=min_samples).fit(scaled_embeddings)\n    \n    cluster_centers = []\n    unique_labels = set(clustering.labels_) - {-1}  # Exclude noise label\n    \n    for label in unique_labels:\n        cluster_points = scaled_embeddings[clustering.labels_ == label]\n        center = scaler.inverse_transform(cluster_points.mean(axis=0, keepdims=True)).flatten()\n        cluster_centers.append(center)\n    \n    # Fallback to mean if no clusters found\n    if not cluster_centers:\n        cluster_centers = [embeddings.mean(axis=0)]\n    \n    return np.array(cluster_centers)\n\n\ndef _get_batch_size(batch_data: TensorDict) -> int:\n    \"\"\"Best-effort batch size extraction from a TensorDict.\"\"\"\n    try:\n        if hasattr(batch_data, \"batch_size\") and batch_data.batch_size is not None:\n            batch_size = batch_data.batch_size\n            if isinstance(batch_size, (tuple, list, torch.Size)):\n                return int(batch_size[0]) if len(batch_size) > 0 else 0\n            return int(batch_size)\n    except Exception:\n        pass\n    for key in (\"responses\", \"input_ids\", \"attention_mask\", \"response_mask\", \"pixel_values\"):\n        if key in batch_data:\n            try:\n                return int(batch_data[key].size(0))\n            except Exception:\n                continue\n    return 0\n\n\ndef 
_extract_local_data(batch_data: TensorDict) -> Dict[str, Any]:\n    \"\"\"Extract local data from batch for reward computation.\"\"\"\n    batch_size = _get_batch_size(batch_data)\n    \n    # Ensure all required fields are present\n    required_fields = [\"complete\", \"vjepa_embedding\", \"task_file_name\", \"finish_step\"]\n    for field in required_fields:\n        if field not in batch_data:\n            raise KeyError(f\"Critical data '{field}' missing from batch in reward computation.\")\n\n    # Extract data\n    completes = np.array(batch_data[\"complete\"].tolist())\n    embeddings = batch_data[\"vjepa_embedding\"].cpu().numpy()\n    finish_steps = batch_data[\"finish_step\"].cpu().numpy()\n    \n    task_file_names = _tensor_to_str_list(batch_data[\"task_file_name\"])\n    task_names = np.array([_extract_task_name(name) for name in task_file_names])\n    \n    # Pre-filter zero embeddings\n    zero_mask = np.all(embeddings == 0, axis=1)\n    valid_mask = ~zero_mask\n    \n    return {\n        \"batch_size\": batch_size,\n        \"embeddings\": embeddings,\n        \"completes\": completes,\n        \"finish_steps\": finish_steps,\n        \"task_names\": task_names,\n        \"valid_mask\": valid_mask,\n        \"zero_mask\": zero_mask,\n    }\n\n\ndef _gather_all_data(\n    local_data: Dict[str, Any],\n    dp_size: int,\n    dp_rank: int,\n    tp_size: int,\n    is_representative: bool\n) -> Tuple[Dict[str, Any], List[int]]:\n    \"\"\"\n    One-shot gather of all data from all DP ranks.\n\n    All ranks in the world must participate in all_gather_object.\n    Only representative ranks (tp_rank==0, last pp stage) send actual data.\n\n    Args:\n        local_data: Local data dict\n        dp_size: Number of DP ranks\n        dp_rank: Current DP rank\n        tp_size: Number of TP ranks (used to calculate representative rank indices)\n        is_representative: Whether this rank is a representative rank\n\n    Returns:\n        Tuple of (global_data, batch_sizes_per_rank)\n    \"\"\"\n    # Prepare data - only representative ranks send actual data, include dp_rank for ordering\n    if is_representative:\n        send_data = {\n            \"dp_rank\": dp_rank,\n            \"batch_size\": local_data[\"batch_size\"],\n            \"embeddings\": local_data[\"embeddings\"].tolist(),\n            \"completes\": local_data[\"completes\"].tolist(),\n            \"task_names\": local_data[\"task_names\"].tolist(),\n            \"valid_mask\": local_data[\"valid_mask\"].tolist(),\n        }\n    else:\n        send_data = None\n\n    # All ranks must participate in all_gather_object\n    world_size = dist.get_world_size() if dist.is_initialized() else 1\n    all_data = [None] * world_size\n    dist.all_gather_object(all_data, send_data)\n\n    # Filter to get only data from representative ranks and sort by dp_rank\n    dp_data = [d for d in all_data if d is not None and isinstance(d, dict) and \"dp_rank\" in d]\n    dp_data = sorted(dp_data, key=lambda x: x[\"dp_rank\"])\n\n    # Extract batch_sizes and merge data\n    batch_sizes = [d[\"batch_size\"] for d in dp_data]\n    global_embeddings = np.concatenate([np.array(d[\"embeddings\"]) for d in dp_data], axis=0)\n    global_completes = np.concatenate([np.array(d[\"completes\"]) for d in dp_data], axis=0)\n    global_task_names = np.concatenate([np.array(d[\"task_names\"]) for d in dp_data], axis=0)\n    global_valid_mask = np.concatenate([np.array(d[\"valid_mask\"]) for d in dp_data], axis=0)\n\n    global_data = {\n        
\"embeddings\": global_embeddings,\n        \"completes\": global_completes,\n        \"task_names\": global_task_names,\n        \"valid_mask\": global_valid_mask,\n        \"batch_size\": len(global_embeddings),\n    }\n\n    return global_data, batch_sizes\n\n\ndef _compute_all_rewards(data: Dict[str, Any], logger) -> np.ndarray:\n    \"\"\"\n    Compute rewards for all samples based on global data.\n    \n    Reward logic:\n    - Success samples: reward = 1.0\n    - Failed samples with valid embeddings: reward = sigmoid-shaped based on distance to success cluster centers\n    - Invalid samples (zero embeddings): reward = 0.0\n    \"\"\"\n    batch_size = data[\"batch_size\"]\n    embeddings = data[\"embeddings\"]\n    completes = data[\"completes\"].astype(bool)\n    task_names = data[\"task_names\"]\n    valid_mask = data[\"valid_mask\"]\n    \n    final_rewards = np.zeros(batch_size, dtype=float)\n    \n    # Success + valid -> reward = 1.0\n    success_mask = completes & valid_mask\n    final_rewards[success_mask] = 1.0\n    \n    # Failed + valid -> reward shaping\n    fail_mask = ~completes & valid_mask\n\n    if not fail_mask.any():\n        return final_rewards\n\n    # Group by task and compute rewards\n    unique_tasks = np.unique(task_names)\n\n    for task in unique_tasks:\n        task_mask = task_names == task\n        task_success_mask = task_mask & success_mask\n        task_fail_mask = task_mask & fail_mask\n\n        success_count = task_success_mask.sum()\n        fail_count = task_fail_mask.sum()\n\n        if success_count == 0 or fail_count == 0:\n            continue\n\n        # Get embeddings\n        success_emb = embeddings[task_success_mask]\n        fail_emb = embeddings[task_fail_mask]\n        fail_indices = np.where(task_fail_mask)[0]\n\n        # Compute cluster centers from success embeddings\n        cluster_centers = _compute_cluster_centers(success_emb)\n\n        # Compute distances from failed samples to nearest cluster center\n        distance_matrix = cdist(fail_emb, cluster_centers, \"euclidean\")\n        min_distances = distance_matrix.min(axis=1)\n\n        # Normalize distances\n        min_dist, max_dist = min_distances.min(), min_distances.max()\n        dist_range = max_dist - min_dist\n\n        if dist_range < 1e-6:\n            normalized_dists = np.full_like(min_distances, 0.5)\n        else:\n            normalized_dists = (min_distances - min_dist) / dist_range\n\n        # Sigmoid mapping: closer to success -> higher reward (max 0.6)\n        sigmoid_steepness = 10.0\n        sigmoid_offset = 0.5\n        sigmoid_inputs = sigmoid_steepness * (sigmoid_offset - normalized_dists)\n        reward_values = 0.6 * special.expit(sigmoid_inputs)\n\n        final_rewards[fail_indices] = reward_values\n\n    return final_rewards\n\n\ndef _build_results(\n    global_rewards: np.ndarray,\n    local_data: Dict[str, Any],\n    dp_rank: int,\n    batch_sizes: List[int]\n) -> List[Dict[str, Any]]:\n    \"\"\"Build result list for local samples only.\"\"\"\n    local_batch_size = local_data[\"batch_size\"]\n    \n    # Calculate slice based on pre-gathered batch_sizes (no extra gather needed)\n    start_idx = sum(batch_sizes[:dp_rank])\n    end_idx = start_idx + local_batch_size\n    local_rewards = global_rewards[start_idx:end_idx]\n    \n    # Build results\n    results = []\n    for i in range(local_batch_size):\n        results.append({\n            \"is_success\": bool(local_data[\"completes\"][i]),\n            \"task_name\": 
local_data[\"task_names\"][i],\n            \"format_correctness\": 1.0,\n            \"is_zero_embedding\": bool(local_data[\"zero_mask\"][i]),\n            \"score\": float(local_rewards[i]),\n        })\n    \n    return results\n\n\ndef compute_embodied_reward(\n    batch_data: TensorDict,\n    compute_only_rank_0: bool = True,\n    **kwargs: Any,\n) -> List[Dict[str, Any]]:\n    \"\"\"\n    Computes rewards based on VJEPA embeddings and task completion status.\n\n    Distributed-aware: uses one-shot all_gather to collect data from all DP ranks.\n\n    Optimization: When compute_only_rank_0=True (default), only rank 0 computes rewards\n    and broadcasts to other ranks, eliminating redundant computation.\n\n    Args:\n        batch_data: TensorDict containing batch information with parallelism info:\n            - dp_size, dp_rank: Data Parallel info\n            - tp_rank, tp_size: Tensor Parallel info\n            - pp_rank, pp_size: Pipeline Parallel info\n        compute_only_rank_0: If True, only rank 0 computes rewards (default: True)\n\n    Returns:\n        A list of dictionaries, each containing detailed score information.\n    \"\"\"\n\n    # === Step 1: Extract parallelism info from batch ===\n    def get_nontensor_value(key, default):\n        val = batch_data.get(key, None)\n        if val is None:\n            return default\n        return val.data if isinstance(val, NonTensorData) else val\n\n    dp_size = get_nontensor_value(\"dp_size\", 1)\n    dp_rank = get_nontensor_value(\"dp_rank\", 0)\n    tp_rank = get_nontensor_value(\"tp_rank\", 0)\n    tp_size = get_nontensor_value(\"tp_size\", 1)\n    pp_rank = get_nontensor_value(\"pp_rank\", 0)\n    pp_size = get_nontensor_value(\"pp_size\", 1)\n\n    # === Step 2: Extract local data ===\n    local_data = _extract_local_data(batch_data)\n    batch_size = local_data[\"batch_size\"]\n\n    # === Step 3: Determine gather requirements ===\n    is_representative = (tp_rank == 0) and (pp_rank == pp_size - 1)\n    need_distributed = dp_size > 1 and dist.is_initialized()\n\n    # === Step 4: Gather all data (one-shot) or use local ===\n    if need_distributed:\n        global_data, batch_sizes = _gather_all_data(local_data, dp_size, dp_rank, tp_size, is_representative)\n    else:\n        global_data = local_data\n        batch_sizes = [batch_size]\n\n    # === Step 5: Compute rewards ===\n    if compute_only_rank_0 and need_distributed:\n        # Optimized path: only rank 0 computes, then broadcast\n        if dp_rank == 0:\n            global_rewards = _compute_all_rewards(global_data, logger)\n            rewards_tensor = torch.tensor(global_rewards, dtype=torch.float32).cuda()\n        else:\n            rewards_tensor = torch.zeros(global_data['batch_size'], dtype=torch.float32).cuda()\n\n        # Broadcast from rank 0 to all DP ranks\n        dist.broadcast(rewards_tensor, src=0)\n        global_rewards = rewards_tensor.cpu().numpy()\n    else:\n        # Original path: all ranks compute (for backward compatibility or single-process)\n        global_rewards = _compute_all_rewards(global_data, logger)\n\n    # === Step 6: Build results for local samples ===\n    if need_distributed:\n        results = _build_results(global_rewards, local_data, dp_rank, batch_sizes)\n    else:\n        results = []\n        for i in range(batch_size):\n            results.append({\n                \"is_success\": bool(local_data[\"completes\"][i]),\n                \"task_name\": local_data[\"task_names\"][i],\n                
\"format_correctness\": 1.0,\n                \"is_zero_embedding\": bool(local_data[\"zero_mask\"][i]),\n                \"score\": float(global_rewards[i]),\n            })\n\n    # === Step 7: Log final statistics (only rank 0) ===\n    if dp_rank == 0:\n        local_rewards = np.array([r[\"score\"] for r in results])\n        num_success = (local_rewards == 1.0).sum()\n        num_partial = ((local_rewards > 0) & (local_rewards < 1.0)).sum()\n        num_failed = (local_rewards == 0).sum()\n\n        logger.info(f\"[REWARD COMPUTE] Completed - \"\n                   f\"Avg: {local_rewards.mean():.4f}, Min: {local_rewards.min():.4f}, Max: {local_rewards.max():.4f}, \"\n                   f\"Success(1.0): {num_success}, Partial(0<r<1): {num_partial}, Failed(0): {num_failed}\")\n\n    return results\n"
  },
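A single-process smoke test of `compute_embodied_reward` (no `torch.distributed` initialization, so the gather/broadcast paths are skipped). The `encode_names` helper below is illustrative, not part of the module; it exists only because `task_file_name` must arrive as a null-padded byte tensor for `_tensor_to_str_list` to decode.

```python
import torch
from tensordict import TensorDict

from siirl.utils.reward_score.embodied import compute_embodied_reward


def encode_names(names, width=64):
    # Null-pad each task file name into a row of a uint8 tensor.
    buf = torch.zeros(len(names), width, dtype=torch.uint8)
    for i, name in enumerate(names):
        raw = name.encode("utf-8")[:width]
        buf[i, : len(raw)] = torch.tensor(list(raw), dtype=torch.uint8)
    return buf


batch = TensorDict(
    {
        "complete": torch.tensor([True, False]),
        "vjepa_embedding": torch.randn(2, 8),
        "task_file_name": encode_names(
            ["libero_goal_task_1_trial_0", "libero_goal_task_1_trial_1"]
        ),
        "finish_step": torch.tensor([120, 300]),
    },
    batch_size=[2],
)

results = compute_embodied_reward(batch)
for r in results:
    # The successful rollout scores 1.0; the failed one gets a shaped score in [0, 0.6].
    print(r["task_name"], r["is_success"], r["score"])
```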
  {
    "path": "siirl/utils/reward_score/geo3k.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport re\n\nfrom mathruler.grader import extract_boxed_content, grade_answer\nimport os\n\nLOG_PATH = os.environ.get(\"REWARD_LOG_PATH\", \"reward.log\")\n\n\ndef format_reward(predict_str: str) -> float:\n    pattern = re.compile(r\"<think>.*</think>.*\\\\boxed\\{.*\\}.*\", re.DOTALL)\n    match_result = re.fullmatch(pattern, predict_str)\n    return 1.0 if match_result else 0.0\n\n\ndef acc_reward(predict_str: str, ground_truth: str) -> float:\n    answer = extract_boxed_content(predict_str)\n    return 1.0 if grade_answer(answer, ground_truth) else 0.0\n\n\ndef format_reward(predict_str: str) -> float:\n    pattern = re.compile(r\"<think>.*</think>.*\\\\boxed\\{.*\\}.*\", re.DOTALL)\n    match_result = re.fullmatch(pattern, predict_str)\n    return 1.0 if match_result else 0.0\n\n\ndef acc_reward(predict_str: str, ground_truth: str) -> float:\n    answer = extract_boxed_content(predict_str)\n    return 1.0 if grade_answer(answer, ground_truth) else 0.0\n\n\ndef compute_score(predict_str: str, ground_truth: str) -> float:\n    return 0.9 * acc_reward(predict_str, ground_truth) + 0.1 * format_reward(predict_str)\n"
  },
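A quick usage note for the scorer above (it relies on `mathruler` for boxed-answer extraction and grading): accuracy contributes 0.9 and the `<think>…</think>` + `\boxed{}` format check contributes 0.1.

```python
from siirl.utils.reward_score import geo3k

good = "<think>Angle sum gives x = 5.</think> The answer is \\boxed{5}."
print(geo3k.compute_score(good, "5"))   # 1.0: correct boxed answer (0.9) + well-formed response (0.1)

wrong = "<think>Guessing.</think> \\boxed{7}"
print(geo3k.compute_score(wrong, "5"))  # 0.1: format reward only
```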
  {
    "path": "siirl/utils/reward_score/gsm8k.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport re\n\n\ndef extract_solution(solution_str, method=\"strict\"):\n    assert method in [\"strict\", \"flexible\"]\n\n    if method == \"strict\":\n        # this also tests the formatting of the model\n        solution = re.search(\"#### (\\\\-?[0-9\\\\.\\\\,]+)\", solution_str)\n        if solution is None:\n            final_answer = None\n        else:\n            final_answer = solution.group(0)\n            final_answer = final_answer.split(\"#### \")[1].replace(\",\", \"\").replace(\"$\", \"\")\n    elif method == \"flexible\":\n        answer = re.findall(\"(\\\\-?[0-9\\\\.\\\\,]+)\", solution_str)\n        final_answer = None\n        if len(answer) == 0:\n            # no reward is there is no answer\n            pass\n        else:\n            invalid_str = [\"\", \".\"]\n            # find the last number that is not '.'\n            for final_answer in reversed(answer):\n                if final_answer not in invalid_str:\n                    break\n    return final_answer\n\n\ndef compute_score(solution_str, ground_truth, method=\"strict\", format_score=0.0, score=1.0):\n    \"\"\"The scoring function for GSM8k.\n\n    Reference: Trung, Luong, et al. \"Reft: Reasoning with reinforced fine-tuning.\" Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 2024.\n\n    Args:\n        solution_str: the solution text\n        ground_truth: the ground truth\n        method: the method to extract the solution, choices are 'strict' and 'flexible'\n        format_score: the score for the format\n        score: the score for the correct answer\n    \"\"\"\n    answer = extract_solution(solution_str=solution_str, method=method)\n    if answer is None:\n        return 0\n    else:\n        if answer == ground_truth:\n            return score\n        else:\n            return format_score\n"
  },
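To make the two extraction modes above concrete:

```python
from siirl.utils.reward_score import gsm8k

response = "Natalia sold 48 / 2 = 24 clips in May, so 48 + 24 = 72 in total.\n#### 72"

print(gsm8k.compute_score(response, "72"))           # 1.0: strict mode found "#### 72"
print(gsm8k.compute_score("The total is 72", "72"))  # 0:   strict mode requires the "#### " marker
print(gsm8k.extract_solution("The total is 72", method="flexible"))  # '72' (last valid number)
```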
  {
    "path": "siirl/utils/reward_score/math.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py\n\n\ndef compute_score(solution_str, ground_truth) -> float:\n    retval = 0.0\n    try:\n        string_in_last_boxed = last_boxed_only_string(solution_str)\n        if string_in_last_boxed is not None:\n            answer = remove_boxed(string_in_last_boxed)\n            if is_equiv(answer, ground_truth):\n                retval = 1.0\n    except Exception as e:\n        print(e)\n\n    return retval\n\n\n# string normalization from https://github.com/EleutherAI/lm-evaluation-harness/blob/master/lm_eval/tasks/hendrycks_math.py\ndef is_equiv(str1, str2, verbose=False):\n    if str1 is None and str2 is None:\n        print(\"WARNING: Both None\")\n        return True\n    if str1 is None or str2 is None:\n        return False\n\n    try:\n        ss1 = strip_string(str1)\n        ss2 = strip_string(str2)\n        if verbose:\n            print(ss1, ss2)\n        return ss1 == ss2\n    except Exception:\n        return str1 == str2\n\n\ndef remove_boxed(s):\n    if \"\\\\boxed \" in s:\n        left = \"\\\\boxed \"\n        assert s[: len(left)] == left\n        return s[len(left) :]\n\n    left = \"\\\\boxed{\"\n\n    assert s[: len(left)] == left\n    assert s[-1] == \"}\"\n\n    return s[len(left) : -1]\n\n\ndef last_boxed_only_string(string):\n    idx = string.rfind(\"\\\\boxed\")\n    if \"\\\\boxed \" in string:\n        return \"\\\\boxed \" + string.split(\"\\\\boxed \")[-1].split(\"$\")[0]\n    if idx < 0:\n        idx = string.rfind(\"\\\\fbox\")\n        if idx < 0:\n            return None\n\n    i = idx\n    right_brace_idx = None\n    num_left_braces_open = 0\n    while i < len(string):\n        if string[i] == \"{\":\n            num_left_braces_open += 1\n        if string[i] == \"}\":\n            num_left_braces_open -= 1\n            if num_left_braces_open == 0:\n                right_brace_idx = i\n                break\n        i += 1\n\n    retval = None if right_brace_idx is None else string[idx : right_brace_idx + 1]\n\n    return retval\n\n\ndef fix_fracs(string):\n    substrs = string.split(\"\\\\frac\")\n    new_str = substrs[0]\n    if len(substrs) > 1:\n        substrs = substrs[1:]\n        for substr in substrs:\n            new_str += \"\\\\frac\"\n            if substr[0] == \"{\":\n                new_str += substr\n            else:\n                try:\n                    assert len(substr) >= 2\n                except:  # noqa: E722\n                    return string\n                a = substr[0]\n                b = substr[1]\n                if b != \"{\":\n                    if len(substr) > 2:\n                        post_substr = substr[2:]\n                        new_str += \"{\" + a + \"}{\" + b + \"}\" + post_substr\n            
        else:\n                        new_str += \"{\" + a + \"}{\" + b + \"}\"\n                else:\n                    if len(substr) > 2:\n                        post_substr = substr[2:]\n                        new_str += \"{\" + a + \"}\" + b + post_substr\n                    else:\n                        new_str += \"{\" + a + \"}\" + b\n    string = new_str\n    return string\n\n\ndef fix_a_slash_b(string):\n    if len(string.split(\"/\")) != 2:\n        return string\n    a = string.split(\"/\")[0]\n    b = string.split(\"/\")[1]\n    try:\n        a = int(a)\n        b = int(b)\n        assert string == \"{}/{}\".format(a, b)\n        new_string = \"\\\\frac{\" + str(a) + \"}{\" + str(b) + \"}\"\n        return new_string\n    except:  # noqa: E722\n        return string\n\n\ndef remove_right_units(string):\n    # \"\\\\text{ \" only ever occurs (at least in the val set) when describing units\n    if \"\\\\text{ \" in string:\n        splits = string.split(\"\\\\text{ \")\n        assert len(splits) == 2\n        return splits[0]\n    else:\n        return string\n\n\ndef fix_sqrt(string):\n    if \"\\\\sqrt\" not in string:\n        return string\n    splits = string.split(\"\\\\sqrt\")\n    new_string = splits[0]\n    for split in splits[1:]:\n        if split[0] != \"{\":\n            a = split[0]\n            new_substr = \"\\\\sqrt{\" + a + \"}\" + split[1:]\n        else:\n            new_substr = \"\\\\sqrt\" + split\n        new_string += new_substr\n    return new_string\n\n\ndef strip_string(string):\n    # linebreaks\n    string = string.replace(\"\\n\", \"\")\n\n    # remove inverse spaces\n    string = string.replace(\"\\\\!\", \"\")\n\n    # replace \\\\ with \\\n    string = string.replace(\"\\\\\\\\\", \"\\\\\")\n\n    # replace tfrac and dfrac with frac\n    string = string.replace(\"tfrac\", \"frac\")\n    string = string.replace(\"dfrac\", \"frac\")\n\n    # remove \\left and \\right\n    string = string.replace(\"\\\\left\", \"\")\n    string = string.replace(\"\\\\right\", \"\")\n\n    # Remove circ (degrees)\n    string = string.replace(\"^{\\\\circ}\", \"\")\n    string = string.replace(\"^\\\\circ\", \"\")\n\n    # remove dollar signs\n    string = string.replace(\"\\\\$\", \"\")\n\n    # remove units (on the right)\n    string = remove_right_units(string)\n\n    # remove percentage\n    string = string.replace(\"\\\\%\", \"\")\n    string = string.replace(\"\\%\", \"\")  # noqa: W605\n\n    # \" 0.\" equivalent to \" .\" and \"{0.\" equivalent to \"{.\" Alternatively, add \"0\" if \".\" is the start of the string\n    string = string.replace(\" .\", \" 0.\")\n    string = string.replace(\"{.\", \"{0.\")\n    # if empty, return empty string\n    if len(string) == 0:\n        return string\n    if string[0] == \".\":\n        string = \"0\" + string\n\n    # to consider: get rid of e.g. \"k = \" or \"q = \" at beginning\n    if len(string.split(\"=\")) == 2 and len(string.split(\"=\")[0]) <= 2:\n        string = string.split(\"=\")[1]\n\n    # fix sqrt3 --> sqrt{3}\n    string = fix_sqrt(string)\n\n    # remove spaces\n    string = string.replace(\" \", \"\")\n\n    # \\frac1b or \\frac12 --> \\frac{1}{b} and \\frac{1}{2}, etc. Even works with \\frac1{72} (but not \\frac{72}1). 
Also does a/b --> \\\\frac{a}{b}\n    string = fix_fracs(string)\n\n    # manually change 0.5 --> \\frac{1}{2}\n    if string == \"0.5\":\n        string = \"\\\\frac{1}{2}\"\n\n    # NOTE: X/Y changed to \\frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y\n    string = fix_a_slash_b(string)\n\n    return string\n"
  },
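A small example of the boxed-answer pipeline above: `last_boxed_only_string` pulls out the final `\boxed{...}`, and `strip_string` normalizes both sides (for instance, `\dfrac` becomes `\frac`) before comparison.

```python
from siirl.utils.reward_score import math as math_reward

solution = (
    "Adding the fractions gives \\frac{1}{2} + \\frac{1}{3} = \\frac{5}{6}, "
    "so the final answer is $\\boxed{\\dfrac{5}{6}}$."
)
# The prediction \dfrac{5}{6} normalizes to \frac{5}{6} and matches the ground truth.
print(math_reward.compute_score(solution, "\\frac{5}{6}"))  # 1.0
```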
  {
    "path": "siirl/utils/reward_score/math_batch.py",
    "content": "# Copyright 2025 Individual Contributor: Mert Unsal\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom siirl.utils.reward_score.math import compute_score\n\n\ndef compute_score_batched(data_sources, solution_strs, ground_truths, extra_infos):\n    \"\"\"\n    This is a demonstration of how the batched reward function should look like.\n    Typically, you want to use batched reward to speed up the process with parallelization\n    \"\"\"\n    return [compute_score(solution_str, ground_truth) for solution_str, ground_truth in zip(solution_strs, ground_truths)]\n"
  },
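Since the docstring above notes that batched reward functions exist mainly to enable parallel scoring, here is a hedged sketch of one such variant; the helper name and worker count are illustrative and not part of the module.

```python
from concurrent.futures import ProcessPoolExecutor

from siirl.utils.reward_score.math import compute_score


def compute_score_batched_parallel(data_sources, solution_strs, ground_truths, extra_infos=None, max_workers=8):
    # Each (solution, ground_truth) pair is independent, so a process pool can map them in parallel.
    with ProcessPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(compute_score, solution_strs, ground_truths))
```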
  {
    "path": "siirl/utils/reward_score/math_dapo.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py\n\nimport re\nfrom typing import Optional\n\n\ndef last_boxed_only_string(string: str) -> Optional[str]:\n    \"\"\"Extract the last LaTeX boxed expression from a string.\n\n    Args:\n        string: Input string containing LaTeX code\n\n    Returns:\n        The last boxed expression or None if not found\n    \"\"\"\n    idx = string.rfind(\"\\\\boxed{\")\n    if idx < 0:\n        return None\n\n    i = idx\n    right_brace_idx = None\n    num_left_braces_open = 0\n\n    while i < len(string):\n        if string[i] == \"{\":\n            num_left_braces_open += 1\n        if string[i] == \"}\":\n            num_left_braces_open -= 1\n            if num_left_braces_open == 0:\n                right_brace_idx = i\n                break\n        i += 1\n\n    return string[idx : right_brace_idx + 1] if right_brace_idx is not None else None\n\n\ndef remove_boxed(s: str) -> str:\n    \"\"\"Remove the LaTeX boxed command from a string.\n\n    Args:\n        s: String with format \"\\\\boxed{content}\"\n\n    Returns:\n        The content inside the boxed command\n    \"\"\"\n    left = \"\\\\boxed{\"\n    assert s[: len(left)] == left, f\"box error: {s}\"\n    assert s[-1] == \"}\", f\"box error: {s}\"\n    return s[len(left) : -1]\n\n\n# Constants for normalization\nSUBSTITUTIONS = [\n    (\"an \", \"\"),\n    (\"a \", \"\"),\n    (\".$\", \"$\"),\n    (\"\\\\$\", \"\"),\n    (r\"\\ \", \"\"),\n    (\" \", \"\"),\n    (\"mbox\", \"text\"),\n    (\",\\\\text{and}\", \",\"),\n    (\"\\\\text{and}\", \",\"),\n    (\"\\\\text{m}\", \"\\\\text{}\"),\n]\n\nREMOVED_EXPRESSIONS = [\n    \"square\",\n    \"ways\",\n    \"integers\",\n    \"dollars\",\n    \"mph\",\n    \"inches\",\n    \"hours\",\n    \"km\",\n    \"units\",\n    \"\\\\ldots\",\n    \"sue\",\n    \"points\",\n    \"feet\",\n    \"minutes\",\n    \"digits\",\n    \"cents\",\n    \"degrees\",\n    \"cm\",\n    \"gm\",\n    \"pounds\",\n    \"meters\",\n    \"meals\",\n    \"edges\",\n    \"students\",\n    \"childrentickets\",\n    \"multiples\",\n    \"\\\\text{s}\",\n    \"\\\\text{.}\",\n    \"\\\\text{\\ns}\",\n    \"\\\\text{}^2\",\n    \"\\\\text{}^3\",\n    \"\\\\text{\\n}\",\n    \"\\\\text{}\",\n    r\"\\mathrm{th}\",\n    r\"^\\circ\",\n    r\"^{\\circ}\",\n    r\"\\;\",\n    r\",\\!\",\n    \"{,}\",\n    '\"',\n    \"\\\\dots\",\n]\n\n\ndef normalize_final_answer(final_answer: str) -> str:\n    \"\"\"Normalize a final answer to a quantitative reasoning question.\n\n    Args:\n        final_answer: The answer string to normalize\n\n    Returns:\n        Normalized answer string\n    \"\"\"\n    final_answer = final_answer.split(\"=\")[-1]\n\n    # Apply substitutions and removals\n    for before, after in 
SUBSTITUTIONS:\n        final_answer = final_answer.replace(before, after)\n    for expr in REMOVED_EXPRESSIONS:\n        final_answer = final_answer.replace(expr, \"\")\n\n    # Extract and normalize LaTeX math\n    final_answer = re.sub(r\"(.*?)(\\$)(.*?)(\\$)(.*)\", \"$\\\\3$\", final_answer)\n    final_answer = re.sub(r\"(\\\\text\\{)(.*?)(\\})\", \"\\\\2\", final_answer)\n    final_answer = re.sub(r\"(\\\\textbf\\{)(.*?)(\\})\", \"\\\\2\", final_answer)\n    final_answer = re.sub(r\"(\\\\overline\\{)(.*?)(\\})\", \"\\\\2\", final_answer)\n    final_answer = re.sub(r\"(\\\\boxed\\{)(.*)(\\})\", \"\\\\2\", final_answer)\n\n    # Normalize shorthand TeX:\n    #  \\fracab -> \\frac{a}{b}\n    #  \\frac{abc}{bef} -> \\frac{abc}{bef}\n    #  \\fracabc -> \\frac{a}{b}c\n    #  \\sqrta -> \\sqrt{a}\n    #  \\sqrtab -> sqrt{a}b\n    final_answer = re.sub(r\"(frac)([^{])(.)\", \"frac{\\\\2}{\\\\3}\", final_answer)\n    final_answer = re.sub(r\"(sqrt)([^{])\", \"sqrt{\\\\2}\", final_answer)\n    final_answer = final_answer.replace(\"$\", \"\")\n\n    # Normalize numbers\n    if final_answer.replace(\",\", \"\").isdigit():\n        final_answer = final_answer.replace(\",\", \"\")\n\n    return final_answer.strip()\n\n\ndef is_correct_minerva(solution_str: str, gt: str, gt_need_extract: bool = False, answer_pattern: str = r\"(?i)Answer\\s*:\\s*([^\\n]+)\") -> tuple[bool, str]:\n    \"\"\"Check if the solution is correct according to Minerva criteria.\n\n    Args:\n        solution_str: The solution string to check\n        gt: The ground truth answer\n        gt_need_extract: Whether the ground truth needs extraction\n        answer_pattern: Regex pattern to extract the answer\n\n    Returns:\n        Tuple of (is_correct, normalized_prediction)\n    \"\"\"\n    # Extract answer from solution\n    match = re.findall(answer_pattern, solution_str)\n    extracted_answer = match[-1] if match else \"[INVALID]\"\n    pred = normalize_final_answer(extracted_answer)\n\n    # Process ground truth\n    if gt_need_extract:\n        gt = normalize_final_answer(remove_boxed(last_boxed_only_string(gt)))\n    else:\n        gt = normalize_final_answer(gt)\n\n    return (pred == gt), pred\n\n\ndef is_correct_strict_box(pred: str, gt: str, pause_tokens_index: Optional[list[int]] = None) -> tuple[int, Optional[str]]:\n    \"\"\"Check if the prediction is correct using strict boxed answer criteria.\n\n    Args:\n        pred: The prediction string\n        gt: The ground truth answer\n        pause_tokens_index: Indices of pause tokens\n\n    Returns:\n        Tuple of (score, extracted_prediction)\n    \"\"\"\n    # Extract the relevant part of the prediction\n    if pause_tokens_index is not None:\n        assert len(pause_tokens_index) == 4\n        pred = pred[pause_tokens_index[-1] - 100 :]\n    else:\n        pred = pred[-100:]\n\n    # Extract and check the boxed answer\n    boxed_pred = last_boxed_only_string(pred)\n    extracted_pred = remove_boxed(boxed_pred) if boxed_pred is not None else None\n\n    return 1 if (extracted_pred == gt) else -1, extracted_pred\n\n\ndef verify(solution_str: str, answer: str, strict_box_verify: bool = False, pause_tokens_index: Optional[list[int]] = None) -> bool:\n    \"\"\"Verify if the solution is correct.\n\n    Args:\n        solution_str: The solution string to verify\n        answer: The ground truth answer\n        strict_box_verify: Whether to use strict box verification\n        pause_tokens_index: Indices of pause tokens\n\n    Returns:\n        True if the 
solution is correct (paired with the extracted prediction), False otherwise\n    \"\"\"\n    if strict_box_verify:\n        correct, pred = is_correct_strict_box(solution_str, answer, pause_tokens_index)\n        return correct == 1, pred\n\n    correct, pred = is_correct_minerva(solution_str, answer)\n    return correct, pred\n\n\ndef compute_score(\n    solution_str: str,\n    ground_truth: str,\n    strict_box_verify: bool = False,\n    pause_tokens_index: Optional[list[int]] = None,\n) -> dict:\n    \"\"\"Compute the reward score for a solution.\n\n    Args:\n        solution_str: The solution string\n        ground_truth: The ground truth answer\n        strict_box_verify: Whether to use strict box verification\n        pause_tokens_index: Indices of pause tokens\n\n    Returns:\n        Dict with 'score' (1.0 for correct, -1.0 for incorrect), 'acc', and 'pred' (the extracted answer)\n    \"\"\"\n    # Limit solution length for efficiency\n    solution_str = solution_str[-300:]  # The longest answer in MATH-500 has 159 characters\n\n    # Verify the solution\n    correct, pred = verify(solution_str, ground_truth, strict_box_verify, pause_tokens_index)\n\n    reward = 1.0 if correct else -1.0\n    acc = correct\n\n    return {\n        \"score\": reward,\n        \"acc\": acc,\n        \"pred\": pred,\n    }\n"
  },
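Unlike the other scorers in this package, `compute_score` above returns a dict rather than a bare float. By default it looks for a trailing `Answer: ...` line (Minerva-style); with `strict_box_verify=True` it instead requires a `\boxed{...}` answer near the end of the response.

```python
from siirl.utils.reward_score import math_dapo

minerva_style = "We solve the system and get x = 14.\nAnswer: 14"
print(math_dapo.compute_score(minerva_style, "14"))
# {'score': 1.0, 'acc': True, 'pred': '14'}

boxed_style = "Therefore the result is \\boxed{14}."
print(math_dapo.compute_score(boxed_style, "14", strict_box_verify=True)["score"])  # 1.0
```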
  {
    "path": "siirl/utils/reward_score/math_verify.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\ntry:\n    from math_verify.errors import TimeoutException\n    from math_verify.metric import math_metric\n    from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig\nexcept ImportError:\n    print(\"To use Math-Verify, please install it first by running `pip install math-verify`.\")\n\n\ndef compute_score(model_output: str, ground_truth: str, timeout_score: float = 0) -> bool:\n    verify_func = math_metric(\n        gold_extraction_target=(LatexExtractionConfig(),),\n        pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),\n    )\n    ret_score = 0.0\n\n    # Wrap the ground truth in \\boxed{} format for verification\n    ground_truth_boxed = \"\\\\boxed{\" + ground_truth + \"}\"\n    try:\n        ret_score, _ = verify_func([ground_truth_boxed], [model_output])\n    except Exception:\n        pass\n    except TimeoutException:\n        ret_score = timeout_score\n\n    return ret_score\n"
  },
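A usage sketch for the Math-Verify wrapper above (requires `pip install math-verify`); the bare ground truth is passed in because the function wraps it in `\boxed{}` itself.

```python
from siirl.utils.reward_score import math_verify

print(math_verify.compute_score("The sum is $\\boxed{\\frac{1}{2}}$.", "1/2"))  # 1.0 (equivalent forms)
print(math_verify.compute_score("The sum is 0.75.", "1/2"))                     # 0.0
```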
  {
    "path": "siirl/utils/reward_score/mm_eureka.py",
    "content": "import os\nimport re\nfrom datetime import datetime\n\nfrom loguru import logger\nimport torch\nfrom math_verify import ExprExtractionConfig, LatexExtractionConfig, StringExtractionConfig, parse\nfrom siirl.utils.extras.patch import verify\n\nchoices = [\"a\", \"b\", \"c\", \"d\"]\n\n\ndef extract_answer_with_tags(text):\n    match = re.search(r\"(<answer>.*?</answer>)\", text)\n    if match:\n        return match.group(1)\n    return None\n\n\ndef accuracy_reward_func(completion, answer):\n    reward = 0.0\n    response = extract_answer_with_tags(completion)\n    if response != None:\n        response = response\n    else:\n        try:\n            response = completion.split(\"<answer>\")[-1]\n        except:\n            response = completion.split(\"\\n\")[-1]\n\n    content, sol = response, answer\n    answer_parsed = content\n    gold_parsed = parse(sol)\n    if len(gold_parsed) != 0:\n        answer_parsed = parse(\n            content,\n            extraction_config=[StringExtractionConfig(), LatexExtractionConfig(), ExprExtractionConfig()],\n        )\n        try:\n            reward = float(verify(answer_parsed, gold_parsed))\n        except Exception:\n            pass\n\n        if reward == 0.0:\n            try:\n                content_match = re.search(r\"<answer>(.*?)</answer>\", completion)\n                student_answer = content_match.group(1).strip() if content_match else content.strip()\n                student_answer = student_answer.replace(\"</answer>\", \"\").replace(\"<answer>\", \"\").strip()\n                for answer in gold_parsed:\n                    if str(answer).lower() in choices:\n                        if str(answer).lower() in student_answer.lower():\n                            choices_other = [choice for choice in choices if choice != str(answer).lower()]\n                            if all(choice not in student_answer.lower() for choice in choices_other):\n                                reward = 1.0\n            except Exception:\n                pass\n    else:\n        reward = 1.0\n        print(\"Failed to parse gold solution: \", sol)\n\n    return reward, answer_parsed\n\n\ndef format_reward_func(completion, **kwargs):\n    pattern = (\n        r\"^(?=(?:.*<think>){1})(?=(?:.*<\\/think>){1})\"\n        r\"(?=(?:.*<answer>){1})(?=(?:.*<\\/answer>){1})\"\n        r\"(?!.*<think>.*<think>)\"\n        r\"(?!.*<\\/think>.*<\\/think>)\"\n        r\"(?!.*<answer>.*<answer>)\"\n        r\"(?!.*<\\/answer>.*<\\/answer>)\"\n        r\".*<think>(.+?)</think>\\s*<answer>.+?</answer>.*$\"\n    )\n    matches = re.search(pattern, completion, re.DOTALL)\n    return 0.5 if matches else 0.0\n\n\ndef compute_score(predict_str: str, ground_truth: str) -> float:\n    try:\n        accuracy_reward, answer_parsed = accuracy_reward_func(predict_str, ground_truth)\n        format_reward = format_reward_func(predict_str)\n    except:\n        logger.warning(f\"Error in computing rewards for prediction: {predict_str}\")\n        accuracy_reward = 0.0\n        format_reward = 0.0\n        answer_parsed = \"\"\n    # LOG_PATH = os.environ.get(\"REWARD_LOG_PATH\", \"reward.log\")\n    # with open(LOG_PATH, \"a\") as f:\n    #     f.write(f\"===============================================================\\n\")\n    #     f.write(\"【Predict Str】: \" + predict_str + \"\\n\")\n    #     f.write(\"【Answer】: \" + ground_truth + \"\\n\")\n    #     f.write(f\"【Accuracy Reward】: {accuracy_reward}\\tFormat Reward: {format_reward}\\n\")\n    #     
f.write(f\"===============================================================\\n\")\n    return 0.9 * accuracy_reward + 0.1 * format_reward\n"
  },
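For reference, the scorer above weights accuracy at 0.9 and format at 0.1, but `format_reward_func` itself tops out at 0.5, so a fully correct, well-formatted response scores 0.95 (assuming `math_verify` is installed for the accuracy check).

```python
from siirl.utils.reward_score import mm_eureka

good = "<think>The two regions have equal area, so the ratio is 1.</think><answer>1</answer>"
print(mm_eureka.compute_score(good, "1"))
# 0.95 = 0.9 * accuracy (1.0) + 0.1 * format reward (0.5)

wrong = "<think>Maybe it doubles.</think><answer>2</answer>"
print(mm_eureka.compute_score(wrong, "1"))  # 0.05: format reward only
```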
  {
    "path": "siirl/utils/reward_score/prime_code/__init__.py",
    "content": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport traceback\n\nfrom .utils import check_correctness as apps_check_correctness\n\n\ndef compute_score(completion, test_cases, continuous=False):\n    # try to get code solution from completion. if the completion is pure code, this will not take effect.\n    solution = completion.split(\"```python\")[-1].split(\"```\")[0]\n    success = False\n    metadata_list = None\n    try:\n        try:\n            if not isinstance(test_cases, dict):\n                test_cases = json.loads(test_cases)\n        except Exception as e:\n            print(f\"Error:{e}\")\n\n        # Complete check on all in-out pairs first. If there is no failure, per-sample test can be skipped.\n        try:\n            res, metadata = apps_check_correctness(in_outs=test_cases, generation=solution, timeout=5, debug=False)\n            metadata = dict(enumerate(metadata))[0]\n            success = all(map(lambda x: x is True, res))\n            if success:\n                return success, metadata\n        except Exception:\n            pass\n\n        test_cases_list = []\n        inputs = test_cases[\"inputs\"]\n        outputs = test_cases[\"outputs\"]\n        for i in range(len(inputs)):\n            test_cases_list.append({\"inputs\": [inputs[i]], \"outputs\": [outputs[i]]})\n\n        if continuous:\n            # per sample test: if continuous score is needed, test first 10 samples regardless of failures\n            # do not test all samples cuz some problems have enormous test cases\n            metadata_list = []\n            res_list = []\n            for test_case_id, test_case in enumerate(test_cases_list):\n                res, metadata = apps_check_correctness(in_outs=test_case, generation=solution, timeout=10, debug=False)\n                try:\n                    metadata = dict(enumerate(metadata))[0]  # metadata can be empty occasionally\n                except Exception:\n                    metadata = {}\n                metadata[\"test_case\"] = {}\n                metadata[\"test_case\"][\"input\"] = str(test_case[\"inputs\"][0])\n                metadata[\"test_case\"][\"output\"] = str(test_case[\"outputs\"][0])\n                metadata[\"test_case\"][\"res\"] = str(res)\n                metadata_list.append(metadata)\n                res_list.extend(res)\n\n                if test_case_id >= 9:\n                    break\n            res_count = len(res_list) if len(res_list) > 0 else 1\n            success = sum(map(lambda x: x is True, res_list)) / res_count\n    except Exception:\n        traceback.print_exc(10)\n        success = False\n        metadata_list = None\n    return success, metadata_list\n"
  },
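A minimal end-to-end call for the code scorer above. Note that the underlying checker actually executes the candidate program against the stdin/stdout pairs (with a timeout and a reliability guard), so treat it like running untrusted code. When the completion is pure code, the Markdown code-fence splitting is a no-op.

```python
from siirl.utils.reward_score import prime_code

completion = "a, b = map(int, input().split())\nprint(a + b)\n"
test_cases = {"inputs": ["1 2\n", "10 5\n"], "outputs": ["3\n", "15\n"]}

success, metadata = prime_code.compute_score(completion, test_cases)
print(success)  # True when every stdin/stdout pair matches
```

With `continuous=True`, a solution that fails the all-pairs check is re-run case by case (capped at the first 10 cases) and `success` becomes the pass fraction instead of a bool.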
  {
    "path": "siirl/utils/reward_score/prime_code/testing_util.py",
    "content": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport ast\nimport faulthandler\nimport json\nimport platform\n\n# to run the solution files we're using a timing based approach\nimport signal\nimport sys\nimport traceback\n\n# used for debugging to time steps\nfrom datetime import datetime\nfrom enum import Enum\n\n# for capturing the stdout\nfrom io import StringIO\n\n# used for testing the code that reads from input\nfrom unittest.mock import mock_open, patch\n\nimport numpy as np\nfrom pyext import RuntimeModule\n\n\ndef truncatefn(s, length=300):\n    assert isinstance(s, str)\n    if len(s) <= length:\n        return s\n\n    return s[: length // 2] + \"...(truncated) ...\" + s[-length // 2 :]\n\n\nclass CODE_TYPE(Enum):\n    call_based = 0\n    standard_input = 1\n\n\n# used to capture stdout as a list\n# from https://stackoverflow.com/a/16571630/6416660\n# alternative use redirect_stdout() from contextlib\nclass Capturing(list):\n    def __enter__(self):\n        self._stdout = sys.stdout\n        sys.stdout = self._stringio = StringIO()\n        # Make closing the StringIO a no-op\n        self._stringio.close = lambda x: 1\n        return self\n\n    def __exit__(self, *args):\n        self.append(self._stringio.getvalue())\n        del self._stringio  # free up some memory\n        sys.stdout = self._stdout\n\n\ndef only_int_check(val):\n    return isinstance(val, int)\n\n\ndef string_int_check(val):\n    return isinstance(val, str) and val.isdigit()\n\n\ndef combined_int_check(val):\n    return only_int_check(val) or string_int_check(val)\n\n\ndef clean_traceback(error_traceback):\n    file_start = error_traceback.find('File \"<string>\"')\n    # print(file_start)\n    error_traceback = \"Traceback (most recent call last):\\n  \" + error_traceback[file_start:]\n    return error_traceback\n\n\ndef run_test(in_outs, test=None, debug=False, timeout=15):\n    \"\"\"\n    if test(generated_code) is not None it'll try to run the code.\n    otherwise it'll just return an input and output pair.\n    \"\"\"\n    # Disable functionalities that can make destructive changes to the test.\n    reliability_guard()\n\n    if debug:\n        print(f\"start = {datetime.now().time()}\")\n\n    if in_outs:\n        if in_outs.get(\"fn_name\") is None:\n            which_type = CODE_TYPE.standard_input  # Standard input\n            method_name = None\n        else:\n            which_type = CODE_TYPE.call_based  # Call-based\n            method_name = in_outs[\"fn_name\"]\n\n    if debug:\n        print(f\"loaded input_output = {datetime.now().time()}\")\n\n    if test is None:\n        raise AssertionError(\"should not happen: test code is none\")\n    elif test is not None:\n        results = []\n        sol = \"from string import *\\nfrom re import *\\nfrom datetime import *\\nfrom collections import *\\nfrom heapq import *\\nfrom bisect import *\\nfrom copy import *\\nfrom math import *\\nfrom random import *\\nfrom 
statistics import *\\nfrom itertools import *\\nfrom functools import *\\nfrom operator import *\\nfrom io import *\\nfrom sys import *\\nfrom json import *\\nfrom builtins import *\\nfrom typing import *\\nimport string\\nimport re\\nimport datetime\\nimport collections\\nimport heapq\\nimport bisect\\nimport copy\\nimport math\\nimport random\\nimport statistics\\nimport itertools\\nimport functools\\nimport operator\\nimport io\\nimport sys\\nimport json\\nsys.setrecursionlimit(6*10**5)\\n\"  # noqa: E501\n        if debug:\n            print(f\"loading test code = {datetime.now().time()}\")\n\n        if which_type == CODE_TYPE.call_based:\n            sol += test\n            if debug:\n                print(f\"sol = {sol}\")\n            signal.alarm(timeout)\n            try:\n                tmp_sol = RuntimeModule.from_string(\"tmp_sol\", \"\", sol)\n                tmp = tmp_sol if \"class Solution\" not in test else tmp_sol.Solution()\n                signal.alarm(0)\n            except Exception as e:\n                signal.alarm(0)\n                error_traceback = traceback.format_exc()\n                if debug:\n                    print(f\"type 0 compilation error = {e}\")\n                results.append(-2)\n                return results, {\n                    \"error\": repr(e),\n                    # \"error_code\": -1,\n                    # \"error_message\": \"Compilation Error\",\n                    \"traceback\": clean_traceback(error_traceback),\n                }\n            signal.alarm(0)\n\n        elif which_type == CODE_TYPE.standard_input:\n            # sol\n            # if code has if __name__ == \"__main__\": then remove it\n            try:\n                astree = ast.parse(test)\n                last_block = astree.body[-1]\n                if isinstance(last_block, ast.If):\n                    condition = last_block.test\n                    if ast.unparse(condition).strip() == \"__name__ == '__main__'\":\n                        test = ast.unparse(astree.body[:-1]) + \"\\n\" + ast.unparse(last_block.body)\n            except Exception:\n                pass\n\n            tmp_test = test.split(\"\\n\")\n\n            new_test = []\n            for x in tmp_test:\n                if (not x.startswith(\"from \")) and (not x.startswith(\"import \")):\n                    new_test.append(\"\\t\" + x + \"\\n\")\n                else:\n                    new_test.append(x + \"\\n\")\n            tmp_test = new_test\n\n            new_test = \"\"\n            started = False\n            for i in tmp_test:\n                if i.startswith(\"\\t\") and not started:\n                    new_test += \"stdin = sys.stdin\\nstdout = sys.stdout\\n\"\n                    new_test += \"def code():\\n\"\n                    new_test += i\n                    started = True\n                elif started and ((i.startswith(\"from \")) or (i.startswith(\"import \"))):\n                    new_test += \"\\t\" + i\n                else:\n                    new_test += i\n            tmp_test = new_test\n\n            sol += tmp_test\n            if debug:\n                print(f\"sol = {sol}\")\n            method_name = \"code\"\n            signal.alarm(timeout)\n            try:\n                tmp_sol = RuntimeModule.from_string(\"tmp_sol\", \"\", sol)\n                tmp = tmp_sol\n                signal.alarm(0)\n            except Exception as e:\n                signal.alarm(0)\n                error_traceback = traceback.format_exc()\n              
  if debug:\n                    print(f\"type 1 compilation error = {e}\")\n                results.append(-2)\n                return results, {\n                    \"error\": repr(e),\n                    # \"error_code\": -1,\n                    # \"error_message\": \"Compilation Error\",\n                    \"traceback\": clean_traceback(error_traceback),\n                }\n            signal.alarm(0)\n        if debug:\n            print(f\"get method = {datetime.now().time()}\")\n\n        try:\n            method = getattr(tmp, method_name)  # get_attr second arg must be str\n        except Exception:\n            signal.alarm(0)\n            error_traceback = traceback.format_exc()\n            error_info = sys.exc_info()\n            print(f\"unable to get function error = {error_info}\")\n            results.append(-2)\n            return results, {\n                \"error\": repr(error_info),\n                # \"error_code\": -1,\n                # \"error_message\": \"Unable to extract code\",\n                \"traceback\": clean_traceback(error_traceback),\n            }\n\n        for index, inputs in enumerate(in_outs[\"inputs\"]):\n            raw_inputs = inputs\n            raw_outputs = in_outs[\"outputs\"][index]\n            if which_type == CODE_TYPE.call_based:\n                inputs = [json.loads(line) for line in inputs.split(\"\\n\")]\n                in_outs[\"outputs\"][index] = json.loads(in_outs[\"outputs\"][index])\n\n                truncate_line_size = 300 // (raw_inputs.count(\"\\n\") + 1)\n                raw_inputs = \"\\n\".join([truncatefn(line, truncate_line_size) for line in raw_inputs.strip().split(\"\\n\")])\n                raw_outputs = truncatefn(raw_outputs, 200)\n            else:\n                raw_inputs = truncatefn(raw_inputs)\n                raw_outputs = truncatefn(raw_outputs, 200)\n            # JSON forces dictionaries to have string keys; this undoes this (assuming a singleton list)\n            try:\n                if isinstance(inputs[0], dict):\n                    inputs = [{int(k): v for k, v in inputs[0].items()}]\n            except Exception:\n                pass\n            try:\n                if isinstance(in_outs[\"outputs\"][index], dict):\n                    in_outs[\"outputs\"][index] = [{int(k): v for k, v in in_outs[\"outputs\"][index].items()}]\n            except Exception:\n                pass\n            try:\n                if isinstance(in_outs[\"outputs\"][index][0], dict):\n                    in_outs[\"outputs\"][index] = [{int(k): v for k, v in in_outs[\"outputs\"][index][0].items()}]\n            except Exception:\n                pass\n\n            if debug:\n                print(f\"time: {datetime.now().time()} testing index = {index}  inputs = {inputs}, {type(inputs)}. 
type = {which_type}\")\n            if which_type == CODE_TYPE.call_based:  # Call-based\n                signal.alarm(timeout)\n                faulthandler.enable()\n                try:\n                    output = method(*inputs)\n                    raw_true_output = output\n\n                    raw_true_output_copy = json.dumps(output)\n                    raw_true_output_copy = truncatefn(raw_true_output_copy, 200)\n\n                    # ground truth sequences are not tuples\n                    if isinstance(output, tuple):\n                        output = list(output)\n\n                    tmp_result = output == in_outs[\"outputs\"][index]\n                    if isinstance(in_outs[\"outputs\"][index], list) and in_outs[\"outputs\"][index]:\n                        tmp_result = tmp_result or (output == in_outs[\"outputs\"][index][0])\n\n                    # ground truth sequences are not tuples\n                    try:\n                        if isinstance(output[0], tuple):\n                            tmp_result = tmp_result or ([list(x) for x in output] == in_outs[\"outputs\"][index][0])\n                    except Exception:\n                        pass\n                    results.append(tmp_result)\n                    if tmp_result is not True:\n                        return results, {\n                            \"output\": raw_true_output_copy,\n                            \"expected\": raw_outputs,\n                            \"inputs\": raw_inputs,\n                            # \"error_code\": -2,\n                            \"error_message\": \"Wrong Answer\",\n                        }\n                    # reset the alarm\n                    signal.alarm(0)\n                except Exception as e:\n                    signal.alarm(0)\n                    error_traceback = traceback.format_exc()\n                    faulthandler.disable()\n                    if debug:\n                        print(f\"Standard input runtime error or time limit exceeded error = {e}\")\n                    results.append(-1)\n                    return results, {\n                        \"error\": repr(e),\n                        \"traceback\": clean_traceback(error_traceback),\n                    }\n                faulthandler.disable()\n                signal.alarm(0)\n                if debug:\n                    print(f\"outputs = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}\")\n            elif which_type == CODE_TYPE.standard_input:  # Standard input\n                faulthandler.enable()\n                passed = False\n\n                if isinstance(inputs, list):\n                    inputs = \"\\n\".join(inputs)\n                if isinstance(in_outs[\"outputs\"][index], list):\n                    in_outs[\"outputs\"][index] = \"\\n\".join(in_outs[\"outputs\"][index])\n\n                signal.alarm(timeout)\n                with Capturing() as output:\n                    try:\n                        call_method(method, inputs)\n                        # reset the alarm\n                        signal.alarm(0)\n                        passed = True\n                    except Exception as e:\n                        # runtime error or took too long\n                        signal.alarm(0)\n                        error_traceback = traceback.format_exc()\n                        print(f\"Call-based runtime error or time limit exceeded error = {repr(e)}{e}\")\n           
             results.append(-1)\n                        return results, {\n                            \"error\": repr(e),\n                            \"traceback\": clean_traceback(error_traceback),\n                        }\n                    signal.alarm(0)\n                raw_true_output = output[0]\n                raw_true_output_copy = truncatefn(raw_true_output, 200)\n                output = raw_true_output.splitlines()\n                if not passed:\n                    if debug:\n                        nl = \"\\n\"\n                        if not isinstance(inputs, list):\n                            print(f\"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl, ' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}\")\n                        else:\n                            print(f\"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}\")\n                    continue\n\n                if passed and debug:\n                    print(f\"==> output = {output}, test outputs = {in_outs['outputs'][index]}\")\n\n                if custom_compare_(output, in_outs[\"outputs\"][index]):\n                    tmp_result = True\n                    results.append(tmp_result)\n                    continue\n\n                # ground truth sequences are expressed as lists not tuples\n                if isinstance(output, tuple):\n                    output = list(output)\n\n                tmp_result = False\n                try:\n                    tmp_result = output == [in_outs[\"outputs\"][index]]\n                    if isinstance(in_outs[\"outputs\"][index], list):\n                        tmp_result = tmp_result or (output == in_outs[\"outputs\"][index])\n                        if isinstance(output[0], str):\n                            tmp_result = tmp_result or ([e.strip() for e in output] == in_outs[\"outputs\"][index])\n                except Exception as e:\n                    if debug:\n                        print(f\"Failed check1 exception = {e}\")\n                    pass\n\n                if tmp_result is True:\n                    results.append(tmp_result)\n                    continue\n\n                # try one more time without \\n\n                if isinstance(in_outs[\"outputs\"][index], list):\n                    for tmp_index, i in enumerate(in_outs[\"outputs\"][index]):\n                        in_outs[\"outputs\"][index][tmp_index] = i.split(\"\\n\")\n                        in_outs[\"outputs\"][index][tmp_index] = [x.strip() for x in in_outs[\"outputs\"][index][tmp_index] if x]\n                else:\n                    in_outs[\"outputs\"][index] = in_outs[\"outputs\"][index].split(\"\\n\")\n                    in_outs[\"outputs\"][index] = list(filter(len, in_outs[\"outputs\"][index]))\n                    in_outs[\"outputs\"][index] = list(map(lambda x: x.strip(), in_outs[\"outputs\"][index]))\n\n                try:\n                    tmp_result = output == [in_outs[\"outputs\"][index]]\n                    if isinstance(in_outs[\"outputs\"][index], list):\n                        tmp_result = tmp_result or (output == in_outs[\"outputs\"][index])\n                except Exception as e:\n                    if debug:\n                        print(f\"Failed check2 exception = {e}\")\n                    pass\n\n                if tmp_result is True:\n                    
results.append(tmp_result)\n                    continue\n\n                # try by converting the output into a split up list too\n                if isinstance(output, list):\n                    output = list(filter(len, output))\n\n                if debug:\n                    nl = \"\\n\"\n                    if not isinstance(inputs, list):\n                        print(f\"@1 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl, ' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]} {tmp_result=}\")\n                    else:\n                        print(f\"@1 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]} {tmp_result=}\")\n\n                if debug:\n                    print(f\"{tmp_result=} @a\")\n\n                try:\n                    tmp_result = output == [in_outs[\"outputs\"][index]]\n                    if isinstance(in_outs[\"outputs\"][index], list):\n                        tmp_result = tmp_result or (output == in_outs[\"outputs\"][index])\n                except Exception as e:\n                    if debug:\n                        print(f\"Failed check3 exception = {e}\")\n                    pass\n\n                if debug:\n                    print(f\"{tmp_result=} @b\")\n\n                try:\n                    all_ints = all(combined_int_check(e1) and combined_int_check(e2) for e1, e2 in zip(output, in_outs[\"outputs\"][index]))\n                    if not all_ints:\n                        if debug:\n                            print([combined_int_check(e1) and combined_int_check(e2) for e1, e2 in zip(output, in_outs[\"outputs\"][index])])\n                        output_float = [float(e) for e in output]\n                        gt_float = [float(e) for e in in_outs[\"outputs\"][index]]\n                        tmp_result = tmp_result or ((len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float))\n                except Exception:\n                    pass\n\n                if debug:\n                    print(f\"{tmp_result=} @c\")\n\n                try:\n                    if isinstance(output[0], list):\n                        all_ints = all(combined_int_check(e1) and combined_int_check(e2) for e1, e2 in zip(output[0], in_outs[\"outputs\"][index]))\n                        if not all_ints:\n                            output_float = [float(e) for e in output[0]]\n                            gt_float = [float(e) for e in in_outs[\"outputs\"][index][0]]\n                            tmp_result = tmp_result or ((len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float))\n                except Exception:\n                    pass\n\n                if tmp_result is True:\n                    results.append(tmp_result)\n                    continue\n\n                if debug:\n                    print(f\"{tmp_result=} @d\")\n                # try by converting the stuff into split up list\n                if isinstance(in_outs[\"outputs\"][index], list):\n                    for tmp_index, i in enumerate(in_outs[\"outputs\"][index]):\n                        in_outs[\"outputs\"][index][tmp_index] = set(i.split())\n                else:\n                    in_outs[\"outputs\"][index] = set(in_outs[\"outputs\"][index].split())\n\n                if debug:\n                    print(f\"{tmp_result=} @e\")\n\n                try:\n                    tmp_result = 
output == in_outs[\"outputs\"][index]\n                except Exception as e:\n                    if debug:\n                        print(f\"Failed check4 exception = {e}\")\n                    continue\n\n                if tmp_result is True:\n                    results.append(tmp_result)\n                    continue\n\n                if debug:\n                    print(f\"{tmp_result=} @f\")\n\n                # try by converting the output into a split up list too\n                if isinstance(output, list):\n                    for tmp_index, i in enumerate(output):\n                        output[tmp_index] = i.split()\n                    output = list(filter(len, output))\n                    for tmp_index, i in enumerate(output):\n                        output[tmp_index] = set(i)\n                else:\n                    output = output.split()\n                    output = list(filter(len, output))\n                    output = set(output)\n\n                if debug:\n                    print(f\"{tmp_result=} @g\")\n\n                if tmp_result is True and debug:\n                    print(\"PASSED\")\n\n                results.append(tmp_result)\n                if tmp_result is not True:\n                    return results, {\n                        \"output\": raw_true_output_copy,\n                        \"expected\": raw_outputs,\n                        \"inputs\": raw_inputs,\n                        # \"error_code\": -2,\n                        \"error_message\": \"Wrong Answer\",\n                    }\n\n                if debug:\n                    nl = \"\\n\"\n                    if not isinstance(inputs, list):\n                        print(f\"@2 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl, ' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}\")\n                    else:\n                        print(f\"@2 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}\")\n\n                    print(f\"results = {results}\")\n\n    return results, {}\n\n\ndef custom_compare_(output, ground_truth):\n    if isinstance(output, list):\n        output_1 = \"\\n\".join(output)\n        if stripped_string_compare(output_1, ground_truth):\n            return True\n\n    if isinstance(output, list):\n        output_2 = [o.lstrip().rstrip() for o in output]\n        output_2 = \"\\n\".join(output_2)\n        if stripped_string_compare(output_2, ground_truth):\n            return True\n\n    return False\n\n\ndef stripped_string_compare(s1, s2):\n    s1 = s1.lstrip().rstrip()\n    s2 = s2.lstrip().rstrip()\n    return s1 == s2\n\n\ndef call_method(method, inputs):\n    if isinstance(inputs, list):\n        inputs = \"\\n\".join(inputs)\n\n    inputs_line_iterator = iter(inputs.split(\"\\n\"))\n\n    # sys.setrecursionlimit(10000)\n\n    # @patch('builtins.input', side_effect=inputs.split(\"\\n\"))\n    @patch(\"builtins.open\", mock_open(read_data=inputs))\n    @patch(\"sys.stdin\", StringIO(inputs))\n    @patch(\"sys.stdin.readline\", lambda *args: next(inputs_line_iterator))\n    @patch(\"sys.stdin.readlines\", lambda *args: inputs.split(\"\\n\"))\n    @patch(\"sys.stdin.read\", lambda *args: inputs)\n    # @patch('sys.stdout.write', print)\n    def _inner_call_method(_method):\n        try:\n            return _method()\n        except SystemExit:\n            pass\n        finally:\n            pass\n\n    
return _inner_call_method(method)\n\n\ndef reliability_guard(maximum_memory_bytes=None):\n    \"\"\"\n    This disables various destructive functions and prevents the generated code\n    from interfering with the test (e.g. fork bomb, killing other processes,\n    removing filesystem files, etc.)\n    WARNING\n    This function is NOT a security sandbox. Untrusted code, including, model-\n    generated code, should not be blindly executed outside of one. See the\n    Codex paper for more information about OpenAI's code sandbox, and proceed\n    with caution.\n    \"\"\"\n\n    if maximum_memory_bytes is not None:\n        import resource\n\n        resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))\n        resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))\n        if platform.uname().system != \"Darwin\":\n            resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))\n\n    faulthandler.disable()\n\n    import builtins\n\n    builtins.exit = None\n    builtins.quit = None\n\n    import os\n\n    os.environ[\"OMP_NUM_THREADS\"] = \"1\"\n\n    os.kill = None\n    os.system = None\n    os.putenv = None\n    os.remove = None\n    os.removedirs = None\n    os.rmdir = None\n    os.fchdir = None\n    os.setuid = None\n    os.fork = None\n    os.forkpty = None\n    os.killpg = None\n    os.rename = None\n    os.renames = None\n    os.truncate = None\n    os.replace = None\n    os.unlink = None\n    os.fchmod = None\n    os.fchown = None\n    os.chmod = None\n    os.chown = None\n    os.chroot = None\n    os.lchflags = None\n    os.lchmod = None\n    os.lchown = None\n    os.getcwd = None\n    os.chdir = None\n\n    import shutil\n\n    shutil.rmtree = None\n    shutil.move = None\n    shutil.chown = None\n\n    import subprocess\n\n    subprocess.Popen = None  # type: ignore\n\n    __builtins__[\"help\"] = None\n\n    import sys\n\n    sys.modules[\"ipdb\"] = None\n    sys.modules[\"joblib\"] = None\n    sys.modules[\"resource\"] = None\n    sys.modules[\"psutil\"] = None\n    sys.modules[\"tkinter\"] = None\n"
  },
  {
    "path": "siirl/utils/reward_score/prime_code/utils.py",
    "content": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Borrowed from: https://huggingface.co/spaces/codeparrot/apps_metric/blob/main/utils.py\n\nimport multiprocessing\nimport os\nimport sys\nimport traceback\nfrom typing import Optional\n\nfrom siirl.utils.reward_score.prime_code.testing_util import run_test\n\n\ndef _temp_run(sample, generation, debug, result, metadata_list, timeout):\n    with open(os.devnull, \"w\") as devnull:\n        sys.stdout = devnull\n        sys.stderr = devnull\n        try:\n            res, metadata = run_test(in_outs=sample, test=generation, debug=debug, timeout=timeout)\n            result.append(res)\n            metadata_list.append(metadata)\n        except Exception:\n            # print(e) # some tracebacks are extremely long.\n            traceback.print_exc(10)\n            result.append([-1 for i in range(len(sample[\"inputs\"]))])\n            metadata_list.append({})\n\n\ndef check_correctness(in_outs: Optional[dict], generation, timeout=10, debug=True):\n    \"\"\"Check correctness of code generation with a global timeout.\n    The global timeout is to catch some extreme/rare cases not handled by the timeouts\n    inside `run_test`\"\"\"\n\n    manager = multiprocessing.Manager()\n    result = manager.list()\n    metadata_list = manager.list()\n    p = multiprocessing.Process(target=_temp_run, args=(in_outs, generation, debug, result, metadata_list, timeout))\n    p.start()\n    p.join(timeout=timeout + 1)\n    if p.is_alive():\n        p.kill()\n        # p.terminate()\n    if not result:\n        # consider that all tests failed\n        result = [[-1 for i in range(len(in_outs[\"inputs\"]))]]\n        if debug:\n            print(\"global timeout\")\n    return result[0], metadata_list\n"
  },
  {
    "path": "siirl/utils/reward_score/prime_math/__init__.py",
    "content": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nAnswer checker API that uses sympy to simplify expressions and check for equality.\n\nCall grade_answer(given_answer: str, ground_truth: str).\n\nFROM: https://github.com/openai/prm800k/blob/main/prm800k/grading/grader.py\n\"\"\"\n\nimport contextlib\nimport math\nimport re\n\nimport sympy\nfrom pylatexenc import latex2text\nfrom sympy.parsing import sympy_parser\n\nfrom siirl.utils.extras.py_functional import timeout_limit\n\nfrom . import math_normalize\nfrom .grader import math_equal\n\n# import math_normalize\n# from grader import math_equal\n\n# sympy might hang -- we don't care about trying to be lenient in these cases\nBAD_SUBSTRINGS = [\"^{\", \"^(\"]\nBAD_REGEXES = [\"\\^[0-9]+\\^\", \"\\^[0-9][0-9]+\"]\nTUPLE_CHARS = \"()[]\"\n\n\ndef _sympy_parse(expr: str):\n    \"\"\"Parses an expression with sympy.\"\"\"\n    py_expr = expr.replace(\"^\", \"**\")\n    return sympy_parser.parse_expr(\n        py_expr,\n        transformations=(sympy_parser.standard_transformations + (sympy_parser.implicit_multiplication_application,)),\n    )\n\n\ndef _parse_latex(expr: str) -> str:\n    \"\"\"Attempts to parse latex to an expression sympy can read.\"\"\"\n    expr = expr.replace(\"\\\\tfrac\", \"\\\\frac\")\n    expr = expr.replace(\"\\\\dfrac\", \"\\\\frac\")\n    expr = expr.replace(\"\\\\frac\", \" \\\\frac\")  # Play nice with mixed numbers.\n    expr = latex2text.LatexNodes2Text().latex_to_text(expr)\n\n    # Replace the specific characters that this parser uses.\n    expr = expr.replace(\"√\", \"sqrt\")\n    expr = expr.replace(\"π\", \"pi\")\n    expr = expr.replace(\"∞\", \"inf\")\n    expr = expr.replace(\"∪\", \"U\")\n    expr = expr.replace(\"·\", \"*\")\n    expr = expr.replace(\"×\", \"*\")\n\n    return expr.strip()\n\n\ndef _is_float(num: str) -> bool:\n    try:\n        float(num)\n        return True\n    except ValueError:\n        return False\n\n\ndef _is_int(x: float) -> bool:\n    try:\n        return abs(x - int(round(x))) <= 1e-7\n    except Exception:\n        return False\n\n\ndef _is_frac(expr: str) -> bool:\n    return bool(re.search(r\"^-?[0-9]+.?/0*[1-9][0-9]*.?$\", expr))\n\n\ndef _str_is_int(x: str) -> bool:\n    try:\n        x = _strip_properly_formatted_commas(x)\n        x = float(x)\n        return abs(x - int(round(x))) <= 1e-7\n    except Exception:\n        return False\n\n\ndef _str_to_int(x: str) -> bool:\n    x = x.replace(\",\", \"\")\n    x = float(x)\n    return int(x)\n\n\ndef _inject_implicit_mixed_number(step: str):\n    \"\"\"\n    Automatically make a mixed number evalable\n    e.g. 
7 3/4 => 7+3/4\n    \"\"\"\n    p1 = re.compile(\"([0-9]) +([0-9])\")\n    step = p1.sub(\"\\\\1+\\\\2\", step)  ## implicit mults\n    return step\n\n\ndef _strip_properly_formatted_commas(expr: str):\n    # We want to be careful because we don't want to strip tuple commas\n    p1 = re.compile(\"(\\d)(,)(\\d\\d\\d)($|\\D)\")\n    while True:\n        next_expr = p1.sub(\"\\\\1\\\\3\\\\4\", expr)\n        if next_expr == expr:\n            break\n        expr = next_expr\n    return next_expr\n\n\ndef _normalize(expr: str) -> str:\n    \"\"\"Normalize answer expressions.\"\"\"\n    if expr is None:\n        return None\n\n    # Remove enclosing `\\text{}`.\n    m = re.search(\"^\\\\\\\\text\\{(?P<text>.+?)\\}$\", expr)\n    if m is not None:\n        expr = m.group(\"text\")\n\n    expr = expr.replace(\"\\\\%\", \"%\")\n    expr = expr.replace(\"\\\\$\", \"$\")\n    expr = expr.replace(\"$\", \"\")\n    expr = expr.replace(\"%\", \"\")\n    expr = expr.replace(\" or \", \" , \")\n    expr = expr.replace(\" and \", \" , \")\n\n    expr = expr.replace(\"million\", \"*10^6\")\n    expr = expr.replace(\"billion\", \"*10^9\")\n    expr = expr.replace(\"trillion\", \"*10^12\")\n\n    for unit in [\n        \"degree\",\n        \"cm\",\n        \"centimeter\",\n        \"meter\",\n        \"mile\",\n        \"second\",\n        \"minute\",\n        \"hour\",\n        \"day\",\n        \"week\",\n        \"month\",\n        \"year\",\n        \"foot\",\n        \"feet\",\n        \"inch\",\n        \"yard\",\n        \"liter\",\n    ]:\n        expr = re.sub(f\"{unit}(es)?(s)? *(\\^[0-9]+)?\", \"\", expr)\n    expr = re.sub(\"\\^ *\\\\\\\\circ\", \"\", expr)\n\n    if len(expr) > 0 and expr[0] == \"{\" and expr[-1] == \"}\":\n        expr = expr[1:-1]\n\n    expr = re.sub(\",\\\\\\\\! 
*\", \"\", expr)\n    if _is_float(expr) and _is_int(float(expr)):\n        expr = str(int(round(float(expr))))\n    if \"\\\\\" in expr:\n        with contextlib.suppress(Exception):\n            expr = _parse_latex(expr)\n\n    # edge case with mixed numbers and negative signs\n    expr = re.sub(\"- *\", \"-\", expr)\n\n    expr = _inject_implicit_mixed_number(expr)\n\n    # don't be case sensitive for text answers\n    expr = expr.lower()\n\n    if _str_is_int(expr):\n        expr = str(_str_to_int(expr))\n\n    return expr\n\n\ndef count_unknown_letters_in_expr(expr: str):\n    expr = expr.replace(\"sqrt\", \"\")\n    expr = expr.replace(\"frac\", \"\")\n    letters_in_expr = set([x for x in expr if x.isalpha()])\n    return len(letters_in_expr)\n\n\ndef should_allow_eval(expr: str):\n    # we don't want to try parsing unknown text or functions of more than two variables\n    if count_unknown_letters_in_expr(expr) > 2:\n        return False\n\n    for bad_string in BAD_SUBSTRINGS:\n        if bad_string in expr:\n            return False\n\n    return all(re.search(bad_regex, expr) is None for bad_regex in BAD_REGEXES)\n\n\n@timeout_limit(seconds=10)\ndef are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str):\n    are_equal = False\n    try:\n        expr = f\"({ground_truth_normalized})-({given_normalized})\"\n        if should_allow_eval(expr):\n            sympy_diff = _sympy_parse(expr)\n            simplified = sympy.simplify(sympy_diff)\n            if simplified == 0:\n                are_equal = True\n    except Exception:\n        pass\n    return are_equal\n\n\ndef split_tuple(expr: str):\n    \"\"\"\n    Split the elements in a tuple/interval, while handling well-formatted commas in large numbers\n    \"\"\"\n    expr = _strip_properly_formatted_commas(expr)\n    if len(expr) == 0:\n        return []\n    if len(expr) > 2 and expr[0] in TUPLE_CHARS and expr[-1] in TUPLE_CHARS and all([ch not in expr[1:-1] for ch in TUPLE_CHARS]):\n        elems = [elem.strip() for elem in expr[1:-1].split(\",\")]\n    else:\n        elems = [expr]\n    return elems\n\n\ndef grade_answer(given_answer: str, ground_truth: str) -> bool:\n    \"\"\"\n    The answer will be considered correct if:\n    (a) it normalizes to the same string as the ground truth answer\n    OR\n    (b) sympy can simplify the difference between the expressions to 0\n    \"\"\"\n    if given_answer is None:\n        return False\n\n    ground_truth_normalized_mathd = math_normalize.normalize_answer(ground_truth)\n    given_answer_normalized_mathd = math_normalize.normalize_answer(given_answer)\n\n    # be at least as lenient as mathd\n    if ground_truth_normalized_mathd == given_answer_normalized_mathd:\n        return True\n\n    ground_truth_normalized = _normalize(ground_truth)\n    given_normalized = _normalize(given_answer)\n\n    if ground_truth_normalized is None:\n        return False\n\n    if ground_truth_normalized == given_normalized:\n        return True\n\n    if len(given_normalized) == 0:\n        return False\n\n    ground_truth_elems = split_tuple(ground_truth_normalized)\n    given_elems = split_tuple(given_normalized)\n\n    if len(ground_truth_elems) > 1 and (ground_truth_normalized[0] != given_normalized[0] or ground_truth_normalized[-1] != given_normalized[-1]) or len(ground_truth_elems) != len(given_elems):\n        is_correct = False\n    else:\n        for ground_truth_elem, given_elem in zip(ground_truth_elems, given_elems):\n            if _is_frac(ground_truth_elem) and 
_is_frac(given_elem):\n                # if fractions aren't reduced, then shouldn't be marked as correct\n                # so, we don't want to allow sympy.simplify in this case\n                is_correct = ground_truth_elem == given_elem\n            elif _str_is_int(ground_truth_elem) != _str_is_int(given_elem):\n                # if the ground truth answer is an integer, we require the given answer to be a strict match (no sympy.simplify)\n                is_correct = False\n            else:\n                try:\n                    is_correct = are_equal_under_sympy(ground_truth_elem, given_elem)\n                except Exception as e:\n                    # if there's an error, we'll just say it's not correct\n                    is_correct = False\n                    print(f\"Error: {e} from are_equal_under_sympy, {ground_truth_elem}, {given_elem}\")\n            if not is_correct:\n                break\n\n    return is_correct\n\n\ndef remove_boxed(s):\n    left = \"\\\\boxed{\"\n    try:\n        assert s[: len(left)] == left\n        assert s[-1] == \"}\"\n        return s[len(left) : -1]\n    except Exception:\n        return None\n\n\ndef _last_boxed_only_string(string):\n    idx = string.rfind(\"\\\\boxed\")\n    if idx < 0:\n        idx = string.rfind(\"\\\\fbox\")\n        if idx < 0:\n            return None\n\n    i = idx\n    left_brace_idx = None\n    right_brace_idx = None\n    num_left_braces_open = 0\n    while i < len(string):\n        if string[i] == \"{\":\n            num_left_braces_open += 1\n            if left_brace_idx is None:\n                left_brace_idx = i\n        elif string[i] == \"}\":\n            num_left_braces_open -= 1\n            if num_left_braces_open == 0:\n                right_brace_idx = i\n                break\n\n        i += 1\n\n    if left_brace_idx is None or right_brace_idx is None:\n        return None\n\n    return string[left_brace_idx + 1 : right_brace_idx].strip()\n\n\ndef match_answer(response):\n    is_matched = False\n    for ans_marker in [\"answer:\", \"answer is\", \"answers are\"]:\n        ans_idx = response.lower().rfind(ans_marker)\n        if ans_idx != -1:\n            is_matched = True\n            response = response[ans_idx + len(ans_marker) :].strip()\n            if response.endswith(\"\\n\"):\n                response = response[:-2]\n\n    for ans_marker in [\"is answer\", \"is the answer\", \"are answers\", \"are the answers\"]:\n        ans_idx = response.lower().rfind(ans_marker)\n        if ans_idx != -1:\n            is_matched = True\n            response = response[:ans_idx].strip()\n            if response.endswith(\"\\n\"):\n                response = response[:-2]\n\n    # Find boxed\n    ans_boxed = _last_boxed_only_string(response)\n    if ans_boxed:\n        is_matched = True\n        response = ans_boxed\n\n    if \". \" in response:\n        dot_idx = response.lower().rfind(\". 
\")\n        if dot_idx != -1:\n            response = response[:dot_idx].strip()\n\n    for ans_marker in [\"be \", \"is \", \"are \", \"=\", \": \", \"get \", \"be\\n\", \"is\\n\", \"are\\n\", \":\\n\", \"get\\n\"]:\n        ans_idx = response.lower().rfind(ans_marker)\n        if ans_idx != -1:\n            is_matched = True\n            response = response[ans_idx + len(ans_marker) :].strip()\n            if response.endswith(\"\\n\"):\n                response = response[:-2]\n\n    is_matched = is_matched if any([c.isdigit() for c in response]) else False  # answer must have a digit\n    # Grade\n    return is_matched, response\n\n\ndef compute_score(model_output: str, ground_truth: str) -> bool:\n    model_output = str(model_output)\n    ground_truth = str(ground_truth)\n\n    is_matched, extracted_model_output = match_answer(model_output)\n    format_correctness = \"Step 2:\" in model_output and \"\\\\box\" in model_output\n\n    # grade simple algebra questions. if succeeded, return; otherwise, proceed to more complex grading\n    if grade_answer(extracted_model_output, ground_truth):\n        return True, True, extracted_model_output\n\n    try:\n        if \"\\pi\" in extracted_model_output or \"\\pi\" in ground_truth:\n            equivs = []\n            for pi in [math.pi, 3.14]:\n                equivs.append(math_equal(extracted_model_output, ground_truth, timeout=True, pi=pi))\n            is_correct = any(equivs)\n        else:\n            is_correct = math_equal(extracted_model_output, ground_truth, timeout=True)\n    except Exception:\n        is_correct = False\n\n    return is_correct, format_correctness, extracted_model_output\n"
  },
  {
    "path": "siirl/utils/reward_score/prime_math/grader.py",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Copyright (c) Microsoft Corporation.\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE\n\n# Copyright (c) 2023 OpenAI\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\n# Copyright (c) 2021 Dan Hendrycks\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\n# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThis logic is largely copied from the Hendrycks' MATH release (math_equivalence), and borrowed from:\n- https://github.com/microsoft/ToRA/blob/main/src/eval/grader.py\n- https://github.com/microsoft/ProphetNet/tree/master/CRITIC\n- https://github.com/openai/prm800k\n\"\"\"\n\nimport contextlib\nimport math\nimport re\nfrom math import isclose\nfrom typing import Union\n\n# sympy related\nfrom sympy import N, simplify\nfrom sympy.parsing.latex import parse_latex\nfrom sympy.parsing.sympy_parser import parse_expr\n\n# siirl related\nfrom siirl.utils.extras.py_functional import timeout_limit\n\n\ndef is_digit(s):\n    try:\n        if \"{,}\" in str(s):\n            num = float(str(s).replace(\"{,}\", \"\"))\n            return True, num\n\n        num = float(str(s).replace(\",\", \"\"))\n        return True, num\n    except ValueError:\n        return False, None\n\n\ndef normalize(answer, pi) -> str:\n    # checking if answer is $<number> and removing $ in that case to compare\n    if isinstance(answer, str) and bool(re.match(r\"\\$\\d+(\\.\\d+)?\", answer)):\n        return answer[1:]\n\n    # checking if answer is <number>% or <number>\\\\% and removing %\n    if isinstance(answer, str) and (bool(re.match(r\"^\\d+(\\.\\d+)?%$\", answer)) or bool(re.match(r\"^\\d+(\\.\\d+)?\\\\%$\", answer))):\n        return answer.replace(\"\\\\%\", \"\").replace(\"%\", \"\")\n\n    # handle base\n    answer = handle_base(answer)\n\n    # handle pi\n    answer = handle_pi(answer, pi)\n\n    return answer\n\n\ndef handle_base(x) -> 
str:\n    if isinstance(x, str) and \"_\" in x:\n        # Due to base\n        x = x.split(\"_\")[0]\n        x = float(x)\n        return int(x)\n    return x\n\n\ndef handle_pi(string, pi):\n    if isinstance(string, str) and \"\\pi\" in string:\n        # Find the first occurrence of \"\\pi\"\n        idx = string.find(\"\\pi\")\n\n        # Iterate over the string and find all occurrences of \"\\pi\" with a valid previous character\n        while idx != -1:\n            if idx > 0 and string[idx - 1].isdigit():\n                # Replace \"\\pi\" with \"*math.pi\" if the previous character is a digit\n                string = string[:idx] + f\"*{pi}\" + string[idx + 3 :]\n            else:\n                # Replace \"\\pi\" with \"1*math.pi\" if the previous character is not a digit\n                string = string[:idx] + f\"1*{pi}\" + string[idx + 3 :]\n\n            # Find the next occurrence of \"\\pi\"\n            idx = string.find(\"\\pi\", idx + 1)\n\n        # Evaluate the expression using eval() function\n        with contextlib.suppress(Exception):\n            string = eval(string)\n\n    return string\n\n\ndef math_equal(\n    prediction: Union[bool, float, str],\n    reference: Union[float, str],\n    include_percentage: bool = True,\n    tolerance: float = 1e-4,\n    timeout: float = 10.0,\n    pi: float = math.pi,\n) -> bool:\n    \"\"\"\n    Exact match of math if and only if:\n    1. numerical equal: both can convert to float and are equal\n    2. symbolic equal: both can convert to sympy expression and are equal\n    \"\"\"\n\n    prediction = normalize(prediction, pi)\n    reference = normalize(reference, pi)\n\n    if isinstance(prediction, str) and len(prediction) > 1000:  # handling weird corner-cases\n        prediction = prediction[:1000]\n\n    # 0. string comparison\n    if isinstance(prediction, str) and isinstance(reference, str):\n        if prediction.strip().lower() == reference.strip().lower():\n            return True\n        if prediction.replace(\" \", \"\") == reference.replace(\" \", \"\"):\n            return True\n\n    try:  # 1. numerical equal\n        if is_digit(prediction)[0] and is_digit(reference)[0]:\n            prediction = is_digit(prediction)[1]\n            reference = is_digit(reference)[1]\n            # number questions\n            gt_result = [reference / 100, reference, reference * 100] if include_percentage else [reference]\n            for item in gt_result:\n                try:\n                    if isclose(item, prediction, rel_tol=tolerance):\n                        return True\n                except Exception:\n                    continue\n            return False\n    except Exception:\n        pass\n\n    if not prediction and prediction not in [0, False]:\n        return False\n\n    # 2. symbolic equal\n    reference = str(reference).strip()\n    prediction = str(prediction).strip()\n\n    ## deal with [], (), {}\n    prediction = format_intervals(prediction)\n\n    pred_str, ref_str = prediction, reference\n    if (prediction.startswith(\"[\") and prediction.endswith(\"]\") and not reference.startswith(\"(\")) or (prediction.startswith(\"(\") and prediction.endswith(\")\") and not reference.startswith(\"[\")):\n        pred_str = pred_str.strip(\"[]()\")\n        ref_str = ref_str.strip(\"[]()\")\n    for s in [\"{\", \"}\", \"(\", \")\"]:\n        ref_str = ref_str.replace(s, \"\")\n        pred_str = pred_str.replace(s, \"\")\n    if pred_str == ref_str:\n        return True\n\n    ## [a, b] vs. 
[c, d], return a==c and b==d\n    if prediction and reference and prediction[0] in \"([\" and prediction[-1] in \")]\" and prediction[0] == reference[0] and prediction[-1] == reference[-1]:\n        pred_parts = prediction[1:-1].split(\",\")\n        ref_parts = reference[1:-1].split(\",\")\n        if len(pred_parts) == len(ref_parts) and all([math_equal(pred_pt, ref_pt, include_percentage, tolerance) for pred_pt, ref_pt in zip(pred_parts, ref_parts)]):\n            return True\n\n    if \",\" in prediction and \",\" in reference:\n        pred_parts = [item.strip() for item in prediction.split(\",\")]\n        ref_parts = [item.strip() for item in reference.split(\",\")]\n\n        if len(pred_parts) == len(ref_parts):\n            return bool(all([math_equal(pred_parts[i], ref_parts[i], include_percentage, tolerance) for i in range(len(pred_parts))]))\n\n    # if we have point == tuple of values\n    if prediction.startswith(\"Point\") and reference[0] == \"(\" and reference[-1] == \")\":\n        pred_parts = prediction[prediction.find(\"(\") + 1 : -1].split(\",\")\n        ref_parts = reference[1:-1].split(\",\")\n        if len(pred_parts) == len(ref_parts) and all([math_equal(pred_pt, ref_pt, include_percentage, tolerance) for pred_pt, ref_pt in zip(pred_parts, ref_parts)]):\n            return True\n\n    # if reference is a matrix\n    if \"\\begin{pmatrix}\" in reference and prediction.startswith(\"Matrix\"):\n        try:\n            pred_matrix = parse_expr(prediction)\n            ref_matrix_items = reference.split()[1:-1:2]\n            if len(pred_matrix) == len(ref_matrix_items) and all([math_equal(pred, ref, include_percentage, tolerance) for ref, pred in zip(ref_matrix_items, pred_matrix)]):\n                return True\n        except Exception:\n            pass\n    elif \"\\begin{pmatrix}\" in reference and prediction.startswith(\"[\") and prediction.endswith(\"]\"):\n        if isinstance(eval(prediction), list):\n            try:\n                pred_matrix = eval(prediction)\n                # ref_matrix_items = reference.split()[1:-1:2]\n                ref_matrix_items = reference.lstrip(\"\\\\begin{pmatrix}\").lstrip(\"\\begin{pmatrix}\").rstrip(\"\\\\end{pmatrix}\").rstrip(\"\\end{pmatrix}\")  # noqa: B005\n                ref_matrix_items = ref_matrix_items.split(\"\\\\\")\n                ref_matrix_items = [row.split(\"&\") if \"&\" in row else row for row in ref_matrix_items]\n                if len(pred_matrix) == len(ref_matrix_items) and all([math_equal(pred, ref, include_percentage, tolerance) for ref, pred in zip(ref_matrix_items, pred_matrix)]):\n                    return True\n            except Exception:\n                pass\n\n    return symbolic_equal(prediction, reference, tolerance, timeout)\n\n\ndef symbolic_equal(a, b, tolerance, timeout=10.0):\n    def _parse(s):\n        for f in [parse_expr, parse_latex]:\n            try:\n                with timeout_limit(seconds=timeout):\n                    return f(s)\n            except TimeoutError:\n                print(f\"Parsing timed out for {s}\")\n                continue\n            except Exception:\n                continue\n        return s\n\n    a = _parse(a)\n    b = _parse(b)\n\n    try:\n        with timeout_limit(seconds=timeout):\n            if simplify(a - b) == 0:\n                return True\n    except TimeoutError:\n        print(f\"Simplification timed out for {a} - {b}\")\n        pass\n    except Exception:\n        pass\n\n    try:\n        with 
timeout_limit(seconds=timeout):\n            if isclose(N(a), N(b), rel_tol=tolerance):\n                return True\n    except TimeoutError:\n        print(f\"Numerical evaluation timed out for {a}, {b}\")\n        pass\n    except Exception:\n        pass\n    return False\n\n\ndef format_intervals(prediction):\n    patterns = {\n        \"Interval(\": r\"^Interval\\((.*)\\)$\",\n        \"Interval.Ropen(\": r\"^Interval\\.Ropen\\((.*)\\)$\",\n        \"Interval.Lopen(\": r\"^Interval\\.Lopen\\((.*)\\)$\",\n        \"Interval.open(\": r\"^Interval\\.open\\((.*)\\)$\",\n    }\n\n    for key, pattern in patterns.items():\n        match = re.match(pattern, prediction)\n        if match:\n            inner_content = match.group(1)\n\n            if key == \"Interval(\":  # Intarval(a, b) == [a, b]\n                return f\"[{inner_content}]\"\n            elif key == \"Interval.Ropen(\":  # Intarval.Ropen(a, b) == [a, b)\n                return f\"[{inner_content})\"\n            elif key == \"Interval.Lopen(\":  # Intarval.Lopen(a, b) == (a, b]\n                return f\"({inner_content}]\"\n            elif key == \"Interval.open(\":  # Intarval.open(a, b) == (a, b)\n                return f\"({inner_content})\"\n\n    return prediction\n"
  },
  {
    "path": "siirl/utils/reward_score/prime_math/math_normalize.py",
    "content": "# Copyright 2024 PRIME team and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Copyright (c) 2021 Dan Hendrycks\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\"\"\"\nThis logic is largely copied from the Hendrycks' MATH release (math_equivalence).\n\nFrom: https://github.com/openai/prm800k/blob/main/prm800k/grading/math_normalize.py\n\"\"\"\n\nimport re\nfrom typing import Optional\n\n\ndef normalize_answer(answer: Optional[str]) -> Optional[str]:\n    if answer is None:\n        return None\n    answer = answer.strip()\n    try:\n        # Remove enclosing `\\text{}`.\n        m = re.search(\"^\\\\\\\\text\\{(?P<text>.+?)\\}$\", answer)\n        if m is not None:\n            answer = m.group(\"text\").strip()\n        return _strip_string(answer)\n    except:  # noqa: E722\n        return answer\n\n\ndef _fix_fracs(string):\n    substrs = string.split(\"\\\\frac\")\n    new_str = substrs[0]\n    if len(substrs) > 1:\n        substrs = substrs[1:]\n        for substr in substrs:\n            new_str += \"\\\\frac\"\n            if substr[0] == \"{\":\n                new_str += substr\n            else:\n                try:\n                    assert len(substr) >= 2\n                except:  # noqa: E722\n                    return string\n                a = substr[0]\n                b = substr[1]\n                if b != \"{\":\n                    if len(substr) > 2:\n                        post_substr = substr[2:]\n                        new_str += \"{\" + a + \"}{\" + b + \"}\" + post_substr\n                    else:\n                        new_str += \"{\" + a + \"}{\" + b + \"}\"\n                else:\n                    if len(substr) > 2:\n                        post_substr = substr[2:]\n                        new_str += \"{\" + a + \"}\" + b + post_substr\n                    else:\n                        new_str += \"{\" + a + \"}\" + b\n    string = new_str\n    return string\n\n\ndef 
_fix_a_slash_b(string):\n    if len(string.split(\"/\")) != 2:\n        return string\n    a = string.split(\"/\")[0]\n    b = string.split(\"/\")[1]\n    try:\n        a = int(a)\n        b = int(b)\n        assert string == \"{}/{}\".format(a, b)\n        new_string = \"\\\\frac{\" + str(a) + \"}{\" + str(b) + \"}\"\n        return new_string\n    except:  # noqa: E722\n        return string\n\n\ndef _remove_right_units(string):\n    # \"\\\\text{ \" only ever occurs (at least in the val set) when describing units\n    if \"\\\\text{ \" in string:\n        splits = string.split(\"\\\\text{ \")\n        assert len(splits) == 2\n        return splits[0]\n    else:\n        return string\n\n\ndef _fix_sqrt(string):\n    if \"\\\\sqrt\" not in string:\n        return string\n    splits = string.split(\"\\\\sqrt\")\n    new_string = splits[0]\n    for split in splits[1:]:\n        if split[0] != \"{\":\n            a = split[0]\n            new_substr = \"\\\\sqrt{\" + a + \"}\" + split[1:]\n        else:\n            new_substr = \"\\\\sqrt\" + split\n        new_string += new_substr\n    return new_string\n\n\ndef _strip_string(string):\n    # linebreaks\n    string = string.replace(\"\\n\", \"\")\n\n    # remove inverse spaces\n    string = string.replace(\"\\\\!\", \"\")\n\n    # replace \\\\ with \\\n    string = string.replace(\"\\\\\\\\\", \"\\\\\")\n\n    # replace tfrac and dfrac with frac\n    string = string.replace(\"tfrac\", \"frac\")\n    string = string.replace(\"dfrac\", \"frac\")\n\n    # remove \\left and \\right\n    string = string.replace(\"\\\\left\", \"\")\n    string = string.replace(\"\\\\right\", \"\")\n\n    # Remove circ (degrees)\n    string = string.replace(\"^{\\\\circ}\", \"\")\n    string = string.replace(\"^\\\\circ\", \"\")\n\n    # remove dollar signs\n    string = string.replace(\"\\\\$\", \"\")\n\n    # remove units (on the right)\n    string = _remove_right_units(string)\n\n    # remove percentage\n    string = string.replace(\"\\\\%\", \"\")\n    string = string.replace(\"\\%\", \"\")\n\n    # \" 0.\" equivalent to \" .\" and \"{0.\" equivalent to \"{.\" Alternatively, add \"0\" if \".\" is the start of the string\n    string = string.replace(\" .\", \" 0.\")\n    string = string.replace(\"{.\", \"{0.\")\n    # if empty, return empty string\n    if len(string) == 0:\n        return string\n    if string[0] == \".\":\n        string = \"0\" + string\n\n    # to consider: get rid of e.g. \"k = \" or \"q = \" at beginning\n    if len(string.split(\"=\")) == 2 and len(string.split(\"=\")[0]) <= 2:\n        string = string.split(\"=\")[1]\n\n    # fix sqrt3 --> sqrt{3}\n    string = _fix_sqrt(string)\n\n    # remove spaces\n    string = string.replace(\" \", \"\")\n\n    # \\frac1b or \\frac12 --> \\frac{1}{b} and \\frac{1}{2}, etc. Even works with \\frac1{72} (but not \\frac{72}1). Also does a/b --> \\\\frac{a}{b}\n    string = _fix_fracs(string)\n\n    # manually change 0.5 --> \\frac{1}{2}\n    if string == \"0.5\":\n        string = \"\\\\frac{1}{2}\"\n\n    # NOTE: X/Y changed to \\frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y\n    string = _fix_a_slash_b(string)\n\n    return string\n"
  },
  {
    "path": "siirl/utils/reward_score/sandbox_fusion/__init__.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport json\nimport logging\nimport traceback\n\nfrom .utils import check_correctness\n\n\"\"\"\nVerify code correctness using the Sandbox Fusion (https://github.com/bytedance/SandboxFusion).\nYou can either deploy the sandbox_fusion service yourself or use the\nFaaS service provided by public cloud, eg: volcengine.com.\n\"\"\"\nlogger = logging.getLogger(__name__)\n\n\ndef compute_score(sandbox_fusion_url, concurrent_semaphore, completion, test_cases, continuous=False, timeout=10):\n    \"\"\"\n    Computes the code score using the remote sandbox API.\n\n    Args:\n        sandbox_fusion_url: The URL of the sandbox_fusion service, eg: \"https://<your service endpoint>/run_code\"\n\n        completion: The completion string containing the code.\n        test_cases: JSON string or dictionary containing \"inputs\" and \"outputs\".\n        continuous: Whether to compute a continuous score (based on the first N test cases).\n        timeout: Timeout for each test case.\n\n    Returns:\n        A tuple (score, metadata_list).\n        score: Float score (0.0 to 1.0).\n        metadata_list: List containing execution metadata for each test case.\n    \"\"\"\n    solution = completion\n    if \"```python\" in completion:\n        solution = completion.split(\"```python\")[-1].split(\"```\")[0]\n    elif \"```\" in completion:\n        # Handle cases like ```\\ncode\\n```\n        parts = completion.split(\"```\")\n        if len(parts) >= 2:\n            solution = parts[1]\n            # Remove potential language specifier like 'python\\n'\n            if \"\\n\" in solution:\n                first_line, rest = solution.split(\"\\n\", 1)\n                if first_line.strip().isalpha():  # Simple check for language name\n                    solution = rest\n    else:\n        return 0.0, [{\"error\": \"Invalid completion (missing code block)\"}]\n\n    try:\n        if not isinstance(test_cases, dict):\n            try:\n                test_cases = json.loads(test_cases)\n            except json.JSONDecodeError as e:\n                logger.error(f\"Failed to parse test_cases JSON: {e}\")\n                return 0.0, [{\"error\": \"Invalid test_cases JSON format\"}]\n\n        if not test_cases or \"inputs\" not in test_cases or \"outputs\" not in test_cases:\n            logger.error(\"Invalid test_cases structure.\")\n            return 0.0, [{\"error\": \"Invalid test_cases structure (missing inputs/outputs)\"}]\n\n        # Check all test cases\n        # Note: The return value of check_correctness might need adaptation here\n        # Assume check_correctness returns (results_list, metadata_list)\n        # results_list contains True, False, or error codes (-1, -2, -3, etc.)\n        res_list, metadata_list = check_correctness(sandbox_fusion_url=sandbox_fusion_url, in_outs=test_cases, generation=solution, timeout=timeout, concurrent_semaphore=concurrent_semaphore)\n\n 
       # Calculate score\n        if not res_list:  # If there are no results (e.g., invalid input)\n            return 0.0, metadata_list\n\n        if continuous:\n            # Calculate pass rate for the first N (e.g., 10) test cases\n            num_to_consider = min(len(res_list), 10)\n            if num_to_consider == 0:\n                score = 0.0\n            else:\n                passed_count = sum(1 for r in res_list[:num_to_consider] if r is True)\n                score = passed_count / num_to_consider\n            # Return all metadata, even if score is based on the first N\n            final_metadata = metadata_list\n        else:\n            # Calculate pass rate for all test cases\n            passed_count = sum(1 for r in res_list if r is True)\n            total_cases = len(res_list)\n            score = passed_count / total_cases if total_cases > 0 else 0.0\n            final_metadata = metadata_list\n\n    except Exception as e:\n        logger.error(f\"Error during compute_score: {e}\")\n        traceback.print_exc()\n        score = 0.0\n        # Try to return partial metadata if available, otherwise return error info\n        final_metadata = metadata_list if \"metadata_list\" in locals() else [{\"error\": f\"Unhandled exception: {e}\"}]\n\n    # Ensure float and list are returned\n    return float(score), final_metadata if isinstance(final_metadata, list) else [final_metadata]\n"
  },
  {
    "path": "siirl/utils/reward_score/sandbox_fusion/utils.py",
    "content": "# Copyright 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport concurrent.futures  # <-- Import concurrent.futures\nimport json\nimport logging\nimport os\nimport threading\nimport time\nimport traceback\nimport uuid\nfrom typing import Any, Dict, List, Optional, Tuple\n\nimport requests\n\nDEFAULT_TIMEOUT = 10  # Default compile and run timeout\nMAX_RETRIES = 3\nINITIAL_RETRY_DELAY = 1\nAPI_TIMEOUT = 10\n\nlogger = logging.getLogger(__name__)\n\n# Define supported languages list (optional, for documentation or validation)\nSUPPORTED_LANGUAGES = [\"python\", \"cpp\", \"nodejs\", \"go\", \"go_test\", \"java\", \"php\", \"csharp\", \"bash\", \"typescript\", \"sql\", \"rust\", \"cuda\", \"lua\", \"R\", \"perl\", \"D_ut\", \"ruby\", \"scala\", \"julia\", \"pytest\", \"junit\", \"kotlin_script\", \"jest\", \"verilog\", \"python_gpu\", \"lean\", \"swift\", \"racket\"]\n\n\ndef call_sandbox_api(sandbox_fusion_url: str, code: str, stdin: str, compile_timeout: int, run_timeout: int, language: str = \"python\") -> Tuple[Optional[Dict[str, Any]], Optional[str]]:  # <-- Remove request_id parameter\n    \"\"\"\n    Calls the remote sandbox API to execute code with retry logic for Gateway Timeout,\n    using increasing delay between retries. Logs internal calls with a unique ID.\n\n    Args:\n        sandbox_fusion_url: The URL of the sandbox fusion API.\n        code: The code string to execute.\n        stdin: The standard input string.\n        compile_timeout: Compile timeout in seconds.\n        run_timeout: Run timeout in seconds.\n        language: The programming language of the code (e.g., \"python\", \"cpp\", \"java\"). 
Defaults to \"python\".\n\n    Returns:\n        A tuple (response_json, error_message).\n        If successful, response_json is the API's returned JSON object, error_message is None.\n        If failed after retries, response_json is None, error_message contains the error information.\n    \"\"\"\n    request_id = str(uuid.uuid4())  # <-- Generate request_id internally\n    log_prefix = f\"[Request ID: {request_id}] \"  # <-- Create log prefix\n\n    if language not in SUPPORTED_LANGUAGES:\n        error_msg = f\"{log_prefix}Unsupported language: {language}\"\n        logger.error(error_msg)\n        return None, error_msg\n\n    payload = json.dumps(\n        {\n            \"compile_timeout\": compile_timeout,\n            \"run_timeout\": run_timeout,\n            \"code\": code,\n            \"stdin\": stdin,\n            \"language\": language,  # Use the passed language parameter\n            \"files\": {},\n            \"fetch_files\": [],\n        }\n    )\n    headers = {\"Content-Type\": \"application/json\", \"Accept\": \"application/json\"}\n    # Calculate a reasonable request timeout based on compile/run timeouts plus a buffer\n    request_timeout = compile_timeout + run_timeout + API_TIMEOUT\n\n    last_error = None  # Store the last error encountered\n\n    for attempt in range(MAX_RETRIES):\n        try:\n            logger.info(f\"{log_prefix}Attempt {attempt + 1}/{MAX_RETRIES}: Calling sandbox API at {sandbox_fusion_url}\")  # <-- Use internal log_prefix\n            response = requests.post(\n                sandbox_fusion_url,\n                headers=headers,\n                data=payload,\n                timeout=request_timeout,  # Use the calculated timeout\n            )\n\n            # Check for Gateway Timeout (504) specifically for retrying\n            if response.status_code == 504:\n                last_error = f\"{log_prefix}API Request Error: Gateway Timeout (504) on attempt {attempt + 1}/{MAX_RETRIES}\"  # <-- Use internal log_prefix\n                logger.warning(last_error)\n                if attempt < MAX_RETRIES - 1:  # Don't sleep after the last attempt\n                    # Calculate increasing delay (e.g., 1s, 2s, 4s, ...) 
or (1s, 2s, 3s, ...)\n                    # Simple linear increase: delay = INITIAL_RETRY_DELAY * (attempt + 1)\n                    # Exponential backoff: delay = INITIAL_RETRY_DELAY * (2 ** attempt)\n                    delay = INITIAL_RETRY_DELAY * (attempt + 1)  # Using linear increase for simplicity\n                    logger.info(f\"{log_prefix}Retrying after {delay} seconds...\")  # <-- Use internal log_prefix\n                    time.sleep(delay)\n                continue  # Go to the next retry attempt\n\n            # Check for other HTTP errors (e.g., 4xx, other 5xx)\n            response.raise_for_status()\n\n            # If successful (status code 2xx)\n            logger.info(f\"{log_prefix}Sandbox API call successful on attempt {attempt + 1}\")  # <-- Use internal log_prefix\n            return response.json(), None\n\n        except requests.exceptions.RequestException as e:\n            last_error = f\"{log_prefix}API Request Error: {e}\"  # <-- Use internal log_prefix\n            break  # Exit retry loop on non-504 request errors\n        except json.JSONDecodeError as e:\n            raw_response_text = response.text if \"response\" in locals() else \"N/A\"\n            last_error = f\"{log_prefix}API Response JSON Decode Error: {e}\"  # <-- Use internal log_prefix\n            break  # Exit retry loop on JSON decode errors\n        except Exception as e:\n            last_error = f\"{log_prefix}Unexpected Error: {e}\"  # <-- Use internal log_prefix\n            break  # Exit retry loop on other unexpected errors\n\n    # If loop finishes without returning success, return the last recorded error\n    logger.error(f\"{log_prefix}Sandbox API call failed. Last error: {last_error}\")  # <-- Use internal log_prefix\n    # Return the error message without the prefix, as the caller doesn't need the internal ID\n    # Ensure API call failure returns error message, leading to -1 in check_correctness\n    return None, last_error.replace(log_prefix, \"API Call Failed: \") if last_error else \"API Call Failed after retries\"\n\n\ndef _process_single_case(case_index: int, stdin_data: Any, expected_output: Any, sandbox_fusion_url: str, generation: str, timeout: int, language: str, concurrent_semaphore: Optional[threading.Semaphore] = None, fn_name: Optional[str] = None) -> Tuple[int, Dict[str, Any]]:\n    \"\"\"Helper function to process a single test case.\"\"\"\n    api_response = None\n    error_msg = None\n    logger.info(f\"Processing test case {case_index + 1}.\")\n\n    current_generation_code = generation\n\n    if fn_name and language == \"python\":\n        # Wrapper assumes stdin_data is a JSON string for function arguments.\n        wrapper_code = f\"\"\"\nimport traceback\nfrom string import *\nfrom re import *\nfrom datetime import *\nfrom collections import *\nfrom heapq import *\nfrom bisect import *\nfrom copy import *\nfrom math import *\nfrom random import *\nfrom statistics import *\nfrom itertools import *\nfrom functools import *\nfrom operator import *\nfrom io import *\nfrom sys import *\nfrom json import *\nfrom builtins import *\nfrom typing import *\nimport string\nimport re\nimport datetime\nimport collections\nimport heapq\nimport bisect\nimport copy\nimport math\nimport random\nimport statistics\nimport itertools\nimport functools\nimport operator\nimport io\nimport sys\nimport json\n\n# === User's Original Code START ===\n{generation}\n# === User's Original Code END ===\n\n_SANDBOX_FN_NAME = \"{fn_name}\"\n\ndef _execute_user_function():\n    # --- 
Input Parsing ---\n    _raw_input_str = sys.stdin.read()\n    _args = []\n    if _raw_input_str.strip(): # If there's input\n        try:\n            _args = [json.loads(line) for line in _raw_input_str.split('\\\\n')]\n        except json.JSONDecodeError as _je:\n            sys.stderr.write(f\"WrapperError: Invalid JSON input for '{{_SANDBOX_FN_NAME}}': {{_je}}\\\\nInput was: {{_raw_input_str[:200]}}\\\\n\")\n            return None, True # result, error_occurred\n\n    # --- Function Location and Execution ---\n    try:\n        _target_callable = None\n        # Try global scope first\n        if _SANDBOX_FN_NAME in globals():\n            _target_callable = globals()[_SANDBOX_FN_NAME]\n        # Else, if 'Solution' class exists, try to get its method\n        elif 'Solution' in globals():\n            _Solution_class = globals()['Solution']\n            # Attempt to instantiate and get method.\n            # Errors (e.g., Solution not a class, instantiation fails, method missing)\n            # will be caught by the broad except block below.\n            _solution_instance = _Solution_class() \n            _target_callable = getattr(_solution_instance, _SANDBOX_FN_NAME)\n        \n        if not _target_callable:\n            sys.stderr.write(f\"WrapperError: Function or method '{{_SANDBOX_FN_NAME}}' not found.\\\\n\")\n            return None, True # result, error_occurred\n\n        _fn_result = _target_callable(*_args)\n        return _fn_result, False # result, no_error\n    except Exception: # Catches errors from Solution instantiation, getattr, or function call\n        sys.stderr.write(f\"Error during setup or execution of '{{_SANDBOX_FN_NAME}}':\\\\n{{traceback.format_exc()}}\\\\n\")\n        return None, True # result, error_occurred\n\nif __name__ == '__main__':\n    _result, _error_occurred = _execute_user_function()\n\n    if not _error_occurred:\n        # Serialize result to stdout\n        if isinstance(_result, (dict, list, tuple)) or _result is None or isinstance(_result, bool):\n            print(json.dumps(_result))\n        elif isinstance(_result, (int, float, str)):\n            print(str(_result)) # Ensure string conversion for print\n        else:\n            # For other types, default to string representation.\n            print(str(_result))\n    # Optional: To explicitly exit with an error code if the sandbox relies on it\n    # else:\n    #    sys.exit(1) \n\"\"\"\n        current_generation_code = wrapper_code\n\n    try:\n        if concurrent_semaphore:\n            # logger.debug(f\"Case {case_index + 1}: Attempting to acquire semaphore.\")\n            with concurrent_semaphore:\n                # logger.debug(f\"Case {case_index + 1}: Semaphore acquired. 
Calling API.\")\n                api_response, error_msg = call_sandbox_api(sandbox_fusion_url=sandbox_fusion_url, code=current_generation_code, stdin=str(stdin_data), compile_timeout=timeout, run_timeout=timeout, language=language)\n            # logger.debug(f\"Case {case_index + 1}: Semaphore released.\")\n        else:\n            api_response, error_msg = call_sandbox_api(sandbox_fusion_url=sandbox_fusion_url, code=current_generation_code, stdin=str(stdin_data), compile_timeout=timeout, run_timeout=timeout, language=language)\n    except Exception as e:\n        error_msg = f\"API Request Exception during check_correctness for case {case_index + 1}: {e}\"\n        logger.error(f\"Case {case_index + 1}: {error_msg}\")\n        traceback.print_exc()\n\n    metadata = {\n        \"case_index\": case_index,\n        \"input\": str(stdin_data),\n        \"expected_output\": str(expected_output),\n        \"api_request_error\": error_msg,\n        \"api_response\": None,\n        \"status\": \"unknown\",\n        \"stdout\": None,\n        \"stderr\": None,\n        \"exit_code\": None,\n        \"duration\": None,\n        \"compile_duration\": None,\n        \"compile_stderr\": None,\n        \"api_status\": None,\n        \"compile_status\": None,\n        \"run_status\": None,\n    }\n    result_status = -1  # Default error: API request error or unknown sandbox error\n\n    if error_msg:\n        metadata[\"status\"] = \"api_error\"\n        result_status = -1  # API request itself failed (includes timeout after retries)\n        logger.error(f\"Case {case_index}: API error occurred: {error_msg}\")\n        # Log code and input only on error for brevity\n        generation_to_log = generation[:200] + \"...\" if len(generation) > 200 else generation\n        logger.error(f\"Case {case_index}: code: {generation_to_log}\")\n        logger.error(f\"Case {case_index}: input: {str(stdin_data)}\")\n    elif api_response:\n        # --- Add debug logging ---\n        logger.debug(f\"Case {case_index}: API Response: {api_response}\")\n        metadata[\"api_response\"] = api_response\n        metadata[\"api_status\"] = api_response.get(\"status\")\n        compile_result = api_response.get(\"compile_result\")\n        run_result = api_response.get(\"run_result\")\n\n        # Extract compile information\n        if compile_result:\n            metadata[\"compile_status\"] = compile_result.get(\"status\")\n            metadata[\"compile_duration\"] = compile_result.get(\"execution_time\")\n            metadata[\"compile_stderr\"] = compile_result.get(\"stderr\")\n\n        # Extract run information\n        if run_result:\n            metadata[\"run_status\"] = run_result.get(\"status\")\n            metadata[\"stdout\"] = run_result.get(\"stdout\")\n            metadata[\"stderr\"] = run_result.get(\"stderr\")  # stderr during runtime\n            metadata[\"exit_code\"] = run_result.get(\"return_code\")\n            metadata[\"duration\"] = run_result.get(\"execution_time\")\n\n        # --- Determine status based on API response ---\n        api_status = metadata[\"api_status\"]\n\n        if api_status == \"SandboxError\":\n            metadata[\"status\"] = \"sandbox_error\"\n            result_status = -1  # Internal sandbox error\n        elif api_status == \"Failed\":\n            # --- Add debug logging ---\n            logger.debug(f\"API returned Failed status. 
Response: {api_response}\")\n            logger.debug(f\"Compile Result: {compile_result}\")\n            logger.debug(f\"Run Result: {run_result}\")\n            # --- Check the logic here ---\n            # Compile failed or timed out\n            is_compile_error = compile_result and (metadata[\"compile_status\"] in [\"Error\", \"TimeLimitExceeded\"] or (metadata[\"compile_status\"] == \"Finished\" and compile_result.get(\"return_code\") != 0))\n            if is_compile_error:\n                # Differentiate between compile_error and compile_timeout based on specific status\n                if metadata[\"compile_status\"] == \"TimeLimitExceeded\":\n                    metadata[\"status\"] = \"compile_timeout\"\n                else:  # Includes Error and Finished but return_code != 0 cases\n                    metadata[\"status\"] = \"compile_error\"\n                result_status = -4\n            # Run failed or timed out\n            elif run_result:\n                # Modified condition: Check for TimeLimitExceeded OR (Finished with non-zero exit code) OR Error status\n                is_runtime_error = metadata[\"run_status\"] == \"TimeLimitExceeded\" or metadata[\"run_status\"] == \"Error\" or (metadata[\"run_status\"] == \"Finished\" and run_result.get(\"return_code\") != 0)\n                if is_runtime_error:\n                    if metadata[\"run_status\"] == \"TimeLimitExceeded\":\n                        metadata[\"status\"] = \"timeout\"  # Runtime timeout\n                        result_status = -3\n                    else:  # Includes Error and Finished with non-zero return_code\n                        metadata[\"status\"] = \"runtime_error\"\n                        result_status = -2\n                else:\n                    # Other Failed status with run_result, classify as unknown failure\n                    logger.warning(f\"Unknown run_status '{metadata['run_status']}' or state within Failed API status.\")\n                    metadata[\"status\"] = \"unknown_failure\"\n                    result_status = -1  # Default to -1\n            else:\n                # Status is Failed but neither a clear compile error nor run_result exists\n                logger.warning(\"API status Failed but cannot determine specific error type (compile/run).\")\n                metadata[\"status\"] = \"unknown_failure_state\"\n                result_status = -1  # Default to -1\n        elif api_status == \"Success\":\n            # Run completed successfully, now check the answer\n            if run_result and metadata[\"run_status\"] == \"Finished\":\n                actual_output = metadata[\"stdout\"] if metadata[\"stdout\"] is not None else \"\"\n                # Note: Output might contain trailing newlines, need normalization\n                if str(actual_output).rstrip(\"\\n\") == str(expected_output).rstrip(\"\\n\"):\n                    result_status = True\n                    metadata[\"status\"] = \"success\"\n                else:\n                    result_status = False\n                    metadata[\"status\"] = \"wrong_answer\"\n            else:\n                # Status is Success but run_result status is not Finished, this is unexpected\n                metadata[\"status\"] = \"unexpected_success_state\"\n                result_status = -1  # Classify as unknown error\n        else:\n            # API returned an unknown top-level status\n            logger.warning(f\"Unknown API status received: {api_status}\")\n            metadata[\"status\"] = 
f\"unknown_api_status_{api_status}\"\n            result_status = -1  # Default to -1\n    else:  # api_response is None and no error_msg (Should not happen with current call_sandbox_api logic)\n        metadata[\"status\"] = \"unknown_api_state\"\n        result_status = -1\n        logger.error(f\"Case {case_index}: Unknown API state (no response and no error message).\")\n    return result_status, metadata\n\n\ndef check_correctness(sandbox_fusion_url: str, in_outs: Optional[dict], generation: str, timeout: int = DEFAULT_TIMEOUT, language: str = \"python\", concurrent_semaphore: Optional[threading.Semaphore] = None) -> Tuple[List[Any], List[Dict[str, Any]]]:\n    \"\"\"\n    Checks the correctness of code generation using the remote sandbox API,\n    processing test cases concurrently.\n\n    Args:\n        sandbox_fusion_url: The URL of the sandbox fusion API.\n        in_outs: Dictionary containing \"inputs\" and \"outputs\" lists.\n        generation: The generated code string.\n        timeout: Timeout for each test case (compile and run share this timeout).\n        language: The programming language of the code.\n\n    Returns:\n        A tuple (results, metadata_list).\n        results: A list containing the test result for each input/output pair\n                 (True/False/-1 api/sandbox err, -2 runtime err, -3 timeout, -4 compile err).\n                 Results are ordered corresponding to the inputs.\n        metadata_list: A list containing metadata dictionaries for each test case,\n                       ordered corresponding to the inputs.\n    \"\"\"\n    logger.info(\"Starting correctness check for generation.\")\n\n    if not in_outs or \"inputs\" not in in_outs or \"outputs\" not in in_outs:\n        logger.warning(\"Invalid in_outs format provided.\")\n        return [-1], [{\"error\": \"Invalid input/output data\"}]\n\n    inputs = in_outs[\"inputs\"]\n    expected_outputs = in_outs[\"outputs\"]\n    fn_name = in_outs.get(\"fn_name\")\n    num_cases = len(inputs)\n    results = [None] * num_cases  # Initialize with placeholders\n    metadata_list = [None] * num_cases  # Initialize with placeholders\n\n    if num_cases == 0:\n        logger.warning(\"Empty inputs provided.\")\n        return [], []\n\n    if len(inputs) != len(expected_outputs):\n        logger.warning(f\"Mismatch between number of inputs ({len(inputs)}) and outputs ({len(expected_outputs)}).\")\n        # Return error based on the number of inputs provided\n        return [-1] * num_cases, [{\"error\": \"Input/output count mismatch\", \"case_index\": i} for i in range(num_cases)]\n\n    first_compile_error_index = -1\n\n    # max_workers is limited by sandbox_fusion_max_concurrent from concurrent_semaphore\n    with concurrent.futures.ThreadPoolExecutor(max_workers=max(32, os.cpu_count() * 5)) as executor:\n        # Submit all tasks, passing the concurrent_semaphore to _process_single_case\n        future_to_index = {executor.submit(_process_single_case, i, stdin_data, expected_outputs[i], sandbox_fusion_url, generation, timeout, language, concurrent_semaphore, fn_name): i for i, stdin_data in enumerate(inputs)}\n\n        # Process results as they complete\n        for future in concurrent.futures.as_completed(future_to_index):\n            index = future_to_index[future]\n            try:\n                result_status, metadata = future.result()\n                results[index] = result_status\n                metadata_list[index] = metadata\n\n                # Check for compile error (-4)\n     
           if result_status == -4:\n                    if first_compile_error_index == -1 or index < first_compile_error_index:\n                        first_compile_error_index = index\n                    # Optimization: could potentially cancel futures for index > first_compile_error_index\n                    # However, cancellation is not guaranteed. Post-processing is safer.\n\n            except Exception as exc:\n                logger.error(f\"Test case {index} generated an exception: {exc}\")\n                traceback.print_exc()\n                results[index] = -1  # Mark as API/internal error\n                metadata_list[index] = {\n                    \"case_index\": index,\n                    \"input\": str(inputs[index]),\n                    \"expected_output\": str(expected_outputs[index]),\n                    \"api_request_error\": f\"Internal execution error: {exc}\",\n                    \"status\": \"internal_error\",\n                }\n\n    # Post-processing for compile errors\n    if first_compile_error_index != -1:\n        logger.warning(f\"Compile error detected in case {first_compile_error_index}. Marking subsequent cases as compile errors.\")\n        for i in range(first_compile_error_index + 1, num_cases):\n            # Only update if not already processed (though it should be None or have a result)\n            if results[i] != -4:  # Avoid overwriting if it somehow already got -4\n                results[i] = -4\n                # Update or create metadata for skipped cases due to compile error\n                if metadata_list[i] is None:  # If future failed before returning metadata\n                    metadata_list[i] = {\n                        \"case_index\": i,\n                        \"input\": str(inputs[i]),\n                        \"expected_output\": str(expected_outputs[i]),\n                        \"api_request_error\": None,\n                        \"status\": \"compile_error_skipped\",  # Indicate skipped due to prior compile error\n                    }\n                else:  # If future completed but result is overridden\n                    metadata_list[i][\"status\"] = \"compile_error_skipped\"\n\n    logger.info(f\"Correctness check finished. Results: {results}\")\n    return results, metadata_list\n"
  },
  {
    "path": "siirl/utils/reward_score/search_r1_like_qa_em.py",
    "content": "# Copyright 2024 Bytedance Ltd. and/or its affiliates\r\n# Copyright 2023-2024 SGLang Team\r\n# Copyright 2025 Search-R1 Contributors\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#     http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n# Adapted from https://github.com/PeterGriffinJin/Search-R1/blob/main/verl/utils/reward_score/qa_em.py\r\n\r\nimport random\r\nimport re\r\nimport string\r\n\r\n\r\ndef normalize_answer(s):\r\n    def remove_articles(text):\r\n        return re.sub(r\"\\b(a|an|the)\\b\", \" \", text)\r\n\r\n    def white_space_fix(text):\r\n        return \" \".join(text.split())\r\n\r\n    def remove_punc(text):\r\n        exclude = set(string.punctuation)\r\n        return \"\".join(ch for ch in text if ch not in exclude)\r\n\r\n    def lower(text):\r\n        return text.lower()\r\n\r\n    return white_space_fix(remove_articles(remove_punc(lower(s))))\r\n\r\n\r\ndef em_check(prediction, golden_answers):\r\n    if isinstance(golden_answers, str):\r\n        golden_answers = [golden_answers]\r\n    normalized_prediction = normalize_answer(prediction)\r\n    score = 0\r\n    for golden_answer in golden_answers:\r\n        golden_answer = normalize_answer(golden_answer)\r\n        if golden_answer == normalized_prediction:\r\n            score = 1\r\n            break\r\n    return score\r\n\r\n\r\ndef subem_check(prediction, golden_answers):\r\n    if isinstance(golden_answers, str):\r\n        golden_answers = [golden_answers]\r\n    normalized_prediction = normalize_answer(prediction)\r\n    score = 0\r\n    for golden_answer in golden_answers:\r\n        golden_answer = normalize_answer(golden_answer)\r\n        if golden_answer in normalized_prediction:\r\n            score = 1\r\n            break\r\n    return score\r\n\r\n\r\ndef extract_solution(solution_str):\r\n    \"\"\"Extract the equation from the solution string.\"\"\"\r\n    # Remove everything before the first \"Assistant:\"\r\n    # if \"Assistant:\" in solution_str:\r\n    #     solution_str = solution_str.split(\"Assistant:\", 1)[1]\r\n    # elif \"<|im_start|>assistant\" in solution_str:\r\n    #     solution_str = solution_str.split(\"<|im_start|>assistant\", 1)[1]\r\n    # else:\r\n    #     return None\r\n    # solution_str = solution_str.split('\\n')[-1]\r\n\r\n    answer_pattern = r\"<answer>(.*?)</answer>\"\r\n    match = re.finditer(answer_pattern, solution_str, re.DOTALL)\r\n    matches = list(match)\r\n\r\n    # If there are 0  matches, return None\r\n    if len(matches) < 1:\r\n        return None\r\n\r\n    # If there are 2 or more matches, return the last one\r\n    return matches[-1].group(1).strip()\r\n\r\n\r\ndef count_answer_tags(text):\r\n    opening_tags = text.count(\"<answer>\")\r\n    closing_tags = text.count(\"</answer>\")\r\n\r\n    return opening_tags, closing_tags\r\n\r\n\r\ndef compute_score(solution_str, ground_truth, method=\"strict\", format_score=0.0, score=1.0):\r\n    \"\"\"The scoring function for exact match (EM).\r\n\r\n    Args:\r\n        solution_str: the solution text\r\n  
      ground_truth: the ground truth\r\n        method: the method to extract the solution, choices are 'strict' and 'flexible'\r\n        format_score: the score for the format\r\n        score: the score for the correct answer\r\n    \"\"\"\r\n    answer = extract_solution(solution_str=solution_str)\r\n    open_count, close_count = count_answer_tags(solution_str)\r\n    do_print = random.randint(1, 64) == 1\r\n\r\n    if do_print:\r\n        print(\"--------------------------------\")\r\n        print(f\"Golden answers: {ground_truth['target']}\")\r\n        if answer is not None:\r\n            print(f\"Extracted answer is not None: {answer}\")\r\n        else:\r\n            print(\"Extracted answer: None!\")\r\n        print(f\"Solution string: {solution_str}\")\r\n\r\n    if answer is None:\r\n        return 0\r\n    else:\r\n        if em_check(answer, ground_truth[\"target\"]):\r\n            if open_count > 10 or close_count > 10:  # prevent output a lot of </answer>\r\n                score = score / 4\r\n                return score\r\n            return score\r\n        else:\r\n            return format_score\r\n\r\n\r\ndef compute_score_subem(solution_str, ground_truth, method=\"strict\", format_score=0.0, score=1.0):\r\n    \"\"\"The scoring function for substring exact match (EM).\r\n\r\n    Args:\r\n        solution_str: the solution text\r\n        ground_truth: the ground truth\r\n        method: the method to extract the solution, choices are 'strict' and 'flexible'\r\n        format_score: the score for the format\r\n        score: the score for the correct answer\r\n    \"\"\"\r\n    answer = extract_solution(solution_str=solution_str)\r\n    do_print = random.randint(1, 64) == 1\r\n\r\n    if do_print:\r\n        print(\"--------------------------------\")\r\n        print(f\"Golden answers: {ground_truth['target']}\")\r\n        print(f\"Extracted answer: {answer}\")\r\n        print(f\"Solution string: {solution_str}\")\r\n\r\n    if answer is None:\r\n        return 0\r\n    else:\r\n        if subem_check(answer, ground_truth[\"target\"]):\r\n            return score\r\n        else:\r\n            return format_score\r\n"
  },
  {
    "path": "tests/__init__.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License."
  },
  {
    "path": "tests/dag/test_config_loader.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport importlib\nimport json\nimport os\nimport sys\nimport unittest\nfrom contextlib import contextmanager\nfrom unittest.mock import patch\n\nfrom loguru import logger\n\nfrom siirl.workers.dag import DAGConfigLoader, Node, NodeRole, NodeType, TaskGraph\nfrom siirl.utils.params import parse_config, ProfilerArguments\nfrom siirl.utils.debug import DistProfiler\n\n\n@contextmanager\ndef capture_loguru_logs(level=\"INFO\"):\n    \"\"\"Capture loguru-based logs.\"\"\"\n    logs = []\n    handler_id = logger.add(lambda msg: logs.append(msg), level=level)\n    yield logs\n    logger.remove(handler_id)\n\n\n# These functions simulate the actual tasks that nodes would execute.\n# They are defined as async to match the Node.run signature.\ndummy_tasks_content = \"\"\"\nimport asyncio\n\nasync def load_data_node(node_config):\n    return {\"data\": \"loaded_data\"}\n\nasync def preprocess_data_node(input_data, node_config):\n    return {\"processed_data\": \"processed_data\"}\n\nasync def train_model(input_data, node_config):\n    return {\"model_weights\": \"trained_model\"}\n\nasync def evaluate_ensemble_node(input_data, node_config):\n    return {\"evaluation_results\": \"eval_results\"}\n\nasync def prepare_and_sync_barrier_data(input_data, node_config):\n    return {\"sync_status\": \"synced\"}\n\nnon_callable_object = \"I am not callable\"\n\nasync def missing_param_func():\n    return \"This function expects no params\"\n\"\"\"\n\n\nclass TestDAGConfigLoader(unittest.TestCase):\n    \"\"\"\n    Unit tests for the DAGConfigLoader class.\n    \"\"\"\n\n    @classmethod\n    def setUpClass(cls):\n        \"\"\"\n        Set up for all tests: Create a dummy my_tasks.py module.\n        \"\"\"\n        cls.dummy_tasks_file = \"sii_rl_test_tasks.py\"\n        with open(cls.dummy_tasks_file, \"w\", encoding=\"utf-8\") as f:\n            f.write(dummy_tasks_content)\n        # Add the directory containing the dummy module to sys.path\n        sys.path.insert(0, os.path.dirname(os.path.abspath(cls.dummy_tasks_file)))\n        # Reload importlib's module cache to ensure it finds the new module\n        importlib.invalidate_caches()\n\n    @classmethod\n    def tearDownClass(cls):\n        \"\"\"\n        Clean up after all tests: Remove the dummy my_tasks.py module.\n        \"\"\"\n        if os.path.exists(cls.dummy_tasks_file):\n            os.remove(cls.dummy_tasks_file)\n        # Remove the dummy module's directory from sys.path\n        if sys.path[0] == os.path.dirname(os.path.abspath(cls.dummy_tasks_file)):\n            sys.path.pop(0)\n        # Clean up any loaded dummy module from sys.modules\n        if \"sii_rl_test_tasks\" in sys.modules:\n            del sys.modules[\"sii_rl_test_tasks\"]\n\n    def setUp(self):\n        \"\"\"\n        Set up before each test: Initialize DAGConfigLoader.\n        \"\"\"\n        self.loader = 
DAGConfigLoader()\n        self.yaml_file = \"test_dag.yaml\"\n        self.json_file = \"test_dag.json\"\n\n    def tearDown(self):\n        \"\"\"\n        Clean up after each test: Remove temporary config files.\n        \"\"\"\n        if os.path.exists(self.yaml_file):\n            os.remove(self.yaml_file)\n        if os.path.exists(self.json_file):\n            os.remove(self.json_file)\n\n    def _create_config_file(self, content: str, file_type: str = \"yaml\"):\n        \"\"\"Helper to create a temporary config file.\"\"\"\n        file_path = self.yaml_file if file_type == \"yaml\" else self.json_file\n        with open(file_path, \"w\", encoding=\"utf-8\") as f:\n            f.write(content)\n        return file_path\n\n    def test_load_valid_yaml_config(self):\n        \"\"\"Test loading a valid YAML configuration.\"\"\"\n        yaml_content = \"\"\"\n        dag_id: \"test_dag_yaml\"\n        description: \"A valid YAML DAG\"\n        global_config:\n          param1: value1\n        nodes:\n          - node_id: \"start_node\"\n            node_type: \"DATA_LOAD\"\n            dependencies: []\n            executable_ref: \"sii_rl_test_tasks.load_data_node\"\n          - node_id: \"middle_node\"\n            node_type: \"COMPUTE\"\n            dependencies: [\"start_node\"]\n            config:\n              compute_param: 100\n            executable_ref: \"sii_rl_test_tasks.preprocess_data_node\"\n          - node_id: \"end_node\"\n            node_type: \"MODEL_TRAIN\"\n            dependencies: [\"middle_node\"]\n            node_role: \"ACTOR\"\n            executable_ref: \"sii_rl_test_tasks.train_model\"\n        \"\"\"\n        file_path = self._create_config_file(yaml_content, \"yaml\")\n        task_graph = self.loader.load_from_file(file_path, \"yaml\")\n\n        self.assertIsInstance(task_graph, TaskGraph)\n        self.assertEqual(task_graph.graph_id, \"test_dag_yaml\")\n        self.assertEqual(len(task_graph.nodes), 3)\n\n        start_node = task_graph.get_node(\"start_node\")\n        self.assertIsNotNone(start_node)\n        self.assertEqual(start_node.node_type, NodeType.DATA_LOAD)\n        self.assertEqual(start_node.node_role, NodeRole.DEFAULT)\n        self.assertEqual(start_node.dependencies, [])\n        self.assertIsNotNone(start_node.executable)\n        self.assertEqual(start_node.executable.__name__, \"load_data_node\")\n\n        middle_node = task_graph.get_node(\"middle_node\")\n        self.assertIsNotNone(middle_node)\n        self.assertEqual(middle_node.node_type, NodeType.COMPUTE)\n        self.assertEqual(middle_node.dependencies, [\"start_node\"])\n        self.assertEqual(middle_node.config, {\"compute_param\": 100, \"_node_id_\": \"middle_node\"})\n        self.assertIsNotNone(middle_node.executable)\n        self.assertEqual(middle_node.executable.__name__, \"preprocess_data_node\")\n\n        end_node = task_graph.get_node(\"end_node\")\n        self.assertIsNotNone(end_node)\n        self.assertEqual(end_node.node_type, NodeType.MODEL_TRAIN)\n        self.assertEqual(end_node.node_role, NodeRole.ACTOR)\n        self.assertEqual(end_node.dependencies, [\"middle_node\"])\n        self.assertIsNotNone(end_node.executable)\n        self.assertEqual(end_node.executable.__name__, \"train_model\")\n\n        # Test topological sort\n        topological_order = task_graph.get_topological_sort()\n        self.assertEqual(topological_order, [\"start_node\", \"middle_node\", \"end_node\"])\n\n    def test_load_valid_json_config(self):\n      
  \"\"\"Test loading a valid JSON configuration.\"\"\"\n        json_content = json.dumps(\n            {\n                \"dag_id\": \"test_dag_json\",\n                \"description\": \"A valid JSON DAG\",\n                \"global_config\": {\"json_param\": \"json_value\"},\n                \"nodes\": [{\"node_id\": \"node_A\", \"node_type\": \"COMPUTE\", \"dependencies\": [], \"executable_ref\": \"sii_rl_test_tasks.load_data_node\"}, {\"node_id\": \"node_B\", \"node_type\": \"COMPUTE\", \"dependencies\": [\"node_A\"], \"executable_ref\": \"sii_rl_test_tasks.preprocess_data_node\"}],\n            }\n        )\n        file_path = self._create_config_file(json_content, \"json\")\n        task_graph = self.loader.load_from_file(file_path, \"json\")\n\n        self.assertIsInstance(task_graph, TaskGraph)\n        self.assertEqual(task_graph.graph_id, \"test_dag_json\")\n        self.assertEqual(len(task_graph.nodes), 2)\n        self.assertEqual(task_graph.get_node(\"node_A\").node_type, NodeType.COMPUTE)\n        self.assertEqual(task_graph.get_node(\"node_B\").dependencies, [\"node_A\"])\n\n    def test_missing_dag_id(self):\n        \"\"\"Test error when 'dag_id' is missing.\"\"\"\n        yaml_content = \"\"\"\n        description: \"Missing ID\"\n        nodes:\n          - node_id: \"node1\"\n            node_type: \"COMPUTE\"\n            dependencies: []\n        \"\"\"\n        file_path = self._create_config_file(yaml_content)\n        with self.assertRaisesRegex(ValueError, \"The 'dag_id' is missing in the configuration\"):\n            self.loader.load_from_file(file_path, \"yaml\")\n\n    def test_missing_nodes_list(self):\n        \"\"\"Test error when 'nodes' list is missing.\"\"\"\n        yaml_content = \"\"\"\n        dag_id: \"test_missing_nodes\"\n        description: \"Missing nodes list\"\n        \"\"\"\n        file_path = self._create_config_file(yaml_content)\n        with self.assertRaisesRegex(ValueError, \"The 'nodes' list is missing in the DAG configuration\"):\n            self.loader.load_from_file(file_path, \"yaml\")\n\n    def test_empty_nodes_list(self):\n        \"\"\"Test loading with an empty 'nodes' list (should be valid).\"\"\"\n        yaml_content = \"\"\"\n        dag_id: \"test_empty_nodes\"\n        nodes: []\n        \"\"\"\n        file_path = self._create_config_file(yaml_content)\n        task_graph = self.loader.load_from_file(file_path, \"yaml\")\n        self.assertIsInstance(task_graph, TaskGraph)\n        self.assertEqual(task_graph.graph_id, \"test_empty_nodes\")\n        self.assertEqual(len(task_graph.nodes), 0)\n\n    def test_missing_node_id(self):\n        \"\"\"Test error when a node is missing 'node_id'.\"\"\"\n        yaml_content = \"\"\"\n        dag_id: \"test_missing_node_id\"\n        nodes:\n          - node_type: \"COMPUTE\"\n            dependencies: []\n        \"\"\"\n        file_path = self._create_config_file(yaml_content)\n        with self.assertRaisesRegex(ValueError, \"The 'node_id' is missing\"):\n            self.loader.load_from_file(file_path, \"yaml\")\n\n    def test_missing_node_type(self):\n        \"\"\"Test error when a node is missing 'node_type'.\"\"\"\n        yaml_content = \"\"\"\n        dag_id: \"test_missing_node_type\"\n        nodes:\n          - node_id: \"node1\"\n            dependencies: []\n        \"\"\"\n        file_path = self._create_config_file(yaml_content)\n        with self.assertRaisesRegex(ValueError, \"Node 'node1' is missing 'node_type'\"):\n            
self.loader.load_from_file(file_path, \"yaml\")\n\n    def test_invalid_node_type(self):\n        \"\"\"Test error with an invalid node type string.\"\"\"\n        yaml_content = \"\"\"\n        dag_id: \"test_invalid_node_type\"\n        nodes:\n          - node_id: \"node1\"\n            node_type: \"INVALID_TYPE\"\n            dependencies: []\n        \"\"\"\n        file_path = self._create_config_file(yaml_content)\n        with self.assertRaisesRegex(ValueError, \"The 'node_type' .* is invalid.\"):\n            self.loader.load_from_file(file_path, \"yaml\")\n\n    def test_invalid_node_role(self):\n        \"\"\"Test error with an invalid node role string.\"\"\"\n        yaml_content = \"\"\"\n        dag_id: \"test_invalid_node_role\"\n        nodes:\n          - node_id: \"node1\"\n            node_type: \"COMPUTE\"\n            node_role: \"INVALID_ROLE\"\n            dependencies: []\n        \"\"\"\n        file_path = self._create_config_file(yaml_content)\n        with self.assertRaisesRegex(ValueError, \"The 'node_role' .* is invalid.\"):\n            self.loader.load_from_file(file_path, \"yaml\")\n\n    def test_duplicate_node_ids(self):\n        \"\"\"Test warning for duplicate node IDs (TaskGraph handles replacement).\"\"\"\n        yaml_content = \"\"\"\n        dag_id: \"test_duplicate_nodes\"\n        nodes:\n          - node_id: \"node1\"\n            node_type: \"COMPUTE\"\n            dependencies: []\n          - node_id: \"node1\"\n            node_type: \"DATA_LOAD\"\n            dependencies: []\n        \"\"\"\n        file_path = self._create_config_file(yaml_content)\n        with self.assertRaisesRegex(ValueError, \"Duplicate node ID\"):\n            task_graph = self.loader.load_from_file(file_path, \"yaml\")\n\n    def test_non_existent_dependency(self):\n        \"\"\"Test error when a dependency does not exist in the graph.\"\"\"\n        yaml_content = \"\"\"\n        dag_id: \"test_non_existent_dep\"\n        nodes:\n          - node_id: \"node1\"\n            node_type: \"COMPUTE\"\n            dependencies: [\"non_existent_node\"]\n        \"\"\"\n        file_path = self._create_config_file(yaml_content)\n        with self.assertRaisesRegex(ValueError, \"The dependency 'non_existent_node' of node 'node1' does not exist in the graph.\"):\n            self.loader.load_from_file(file_path, \"yaml\")\n\n    def test_circular_dependency(self):\n        \"\"\"Test error when a circular dependency is detected.\"\"\"\n        yaml_content = \"\"\"\n        dag_id: \"test_circular_dep\"\n        nodes:\n          - node_id: \"nodeA\"\n            node_type: \"COMPUTE\"\n            dependencies: [\"nodeB\"]\n          - node_id: \"nodeB\"\n            node_type: \"COMPUTE\"\n            dependencies: [\"nodeC\"]\n          - node_id: \"nodeC\"\n            node_type: \"COMPUTE\"\n            dependencies: [\"nodeA\"]\n        \"\"\"\n        file_path = self._create_config_file(yaml_content)\n        with self.assertRaisesRegex(ValueError, \"There are circular dependencies in graph 'test_circular_dep'\"):\n            self.loader.load_from_file(file_path, \"yaml\")\n\n    def test_executable_ref_not_found_module(self):\n        \"\"\"Test error when executable_ref points to a non-existent module.\"\"\"\n        yaml_content = \"\"\"\n        dag_id: \"test_invalid_exec_ref_module\"\n        nodes:\n          - node_id: \"node1\"\n            node_type: \"COMPUTE\"\n            executable_ref: \"non_existent_module.some_function\"\n            dependencies: 
[]\n        \"\"\"\n        file_path = self._create_config_file(yaml_content)\n        with self.assertRaisesRegex(ImportError, \"Failed to load the executable function from 'non_existent_module.some_function'\"):\n            self.loader.load_from_file(file_path, \"yaml\")\n\n    def test_executable_ref_not_found_function(self):\n        \"\"\"Test error when executable_ref points to a non-existent function in a valid module.\"\"\"\n        yaml_content = f\"\"\"\n        dag_id: \"test_invalid_exec_ref_function\"\n        nodes:\n          - node_id: \"node1\"\n            node_type: \"COMPUTE\"\n            executable_ref: \"{self.dummy_tasks_file.replace(\".py\", \"\")}.non_existent_function\"\n            dependencies: []\n        \"\"\"\n        file_path = self._create_config_file(yaml_content)\n        with self.assertRaisesRegex(ImportError, \"Failed to load the executable function from 'sii_rl_test_tasks.non_existent_function'\"):\n            self.loader.load_from_file(file_path, \"yaml\")\n\n    def test_executable_ref_not_callable(self):\n        \"\"\"Test error when executable_ref points to an object that is not callable.\"\"\"\n        yaml_content = f\"\"\"\n        dag_id: \"test_exec_ref_not_callable\"\n        nodes:\n          - node_id: \"node1\"\n            node_type: \"COMPUTE\"\n            executable_ref: \"{self.dummy_tasks_file.replace(\".py\", \"\")}.non_callable_object\"\n            dependencies: []\n        \"\"\"\n        file_path = self._create_config_file(yaml_content)\n        with self.assertRaisesRegex(ImportError, \"The object resolved from 'sii_rl_test_tasks.non_callable_object' is not callable.\"):\n            self.loader.load_from_file(file_path, \"yaml\")\n\n    def test_ref_resolution(self):\n        \"\"\"Test correct resolution of !Ref tags.\"\"\"\n        yaml_content = \"\"\"\n        dag_id: \"test_ref_resolution\"\n        global_config:\n          default_batch_size: 64\n          data_source: \"s3://my-bucket/data\"\n          nested:\n            level1:\n              level2: \"deep_value\"\n        nodes:\n          - node_id: \"node1\"\n            node_type: \"DATA_LOAD\"\n            dependencies: []\n            config:\n              batch_size: !Ref global_config.default_batch_size\n              source: !Ref global_config.data_source\n              deep_param: !Ref global_config.nested.level1.level2\n            executable_ref: \"sii_rl_test_tasks.load_data_node\"\n        \"\"\"\n        file_path = self._create_config_file(yaml_content)\n        task_graph = self.loader.load_from_file(file_path, \"yaml\")\n        node1 = task_graph.get_node(\"node1\")\n        self.assertEqual(node1.config[\"batch_size\"], 64)\n        self.assertEqual(node1.config[\"source\"], \"s3://my-bucket/data\")\n        self.assertEqual(node1.config[\"deep_param\"], \"deep_value\")\n\n    def test_ref_resolution_invalid_path(self):\n        \"\"\"Test error when !Ref points to an invalid path.\"\"\"\n        yaml_content = \"\"\"\n        dag_id: \"test_ref_invalid_path\"\n        global_config:\n          param1: value1\n        nodes:\n          - node_id: \"node1\"\n            node_type: \"COMPUTE\"\n            dependencies: []\n            config:\n              invalid_ref: !Ref global_config.non_existent_key\n            executable_ref: \"sii_rl_test_tasks.load_data_node\"\n        \"\"\"\n        file_path = self._create_config_file(yaml_content)\n        with self.assertRaisesRegex(ValueError, \"Unresolved reference 
'global_config.non_existent_key'.\"):\n            self.loader.load_from_file(file_path, \"yaml\")\n\n    def test_ref_resolution_invalid_path_nested(self):\n        \"\"\"Test error when !Ref points to an invalid nested path.\"\"\"\n        yaml_content = \"\"\"\n        dag_id: \"test_ref_invalid_nested_path\"\n        global_config:\n          param1:\n            sub_param: value1\n        nodes:\n          - node_id: \"node1\"\n            node_type: \"COMPUTE\"\n            dependencies: []\n            config:\n              invalid_ref: !Ref global_config.param1.non_existent_sub_key\n            executable_ref: \"sii_rl_test_tasks.load_data_node\"\n        \"\"\"\n        file_path = self._create_config_file(yaml_content)\n        with self.assertRaisesRegex(ValueError, \"Unresolved reference 'global_config.param1.non_existent_sub_key'.\"):\n            self.loader.load_from_file(file_path, \"yaml\")\n\n    def test_node_role_validation(self):\n        \"\"\"Test node role validation for non-model node types.\"\"\"\n        # This should pass: COMPUTE node with DEFAULT role\n        yaml_content_valid = \"\"\"\n        dag_id: \"test_role_valid\"\n        nodes:\n          - node_id: \"node1\"\n            node_type: \"COMPUTE\"\n            node_role: \"DEFAULT\"\n            dependencies: []\n        \"\"\"\n        file_path_valid = self._create_config_file(yaml_content_valid)\n        task_graph_valid = self.loader.load_from_file(file_path_valid, \"yaml\")\n        self.assertEqual(task_graph_valid.get_node(\"node1\").node_role, NodeRole.DEFAULT)\n\n        # This should fail: DATA_LOAD node with ACTOR role\n        yaml_content_invalid = \"\"\"\n        dag_id: \"test_role_invalid\"\n        nodes:\n          - node_id: \"node2\"\n            node_type: \"DATA_LOAD\"\n            node_role: \"ACTOR\"\n            dependencies: []\n        \"\"\"\n        file_path_invalid = self._create_config_file(yaml_content_invalid)\n        with self.assertRaisesRegex(ValueError, \"The role type of non-model nodes must be DEFAULT\"):\n            self.loader.load_from_file(file_path_invalid, \"yaml\")\n\n        # This should pass: MODEL_TRAIN node with ACTOR role\n        yaml_content_model_actor = \"\"\"\n        dag_id: \"test_model_actor\"\n        nodes:\n          - node_id: \"node3\"\n            node_type: \"MODEL_TRAIN\"\n            node_role: \"ACTOR\"\n            dependencies: []\n        \"\"\"\n        file_path_model_actor = self._create_config_file(yaml_content_model_actor)\n        task_graph_model_actor = self.loader.load_from_file(file_path_model_actor, \"yaml\")\n        self.assertEqual(task_graph_model_actor.get_node(\"node3\").node_role, NodeRole.ACTOR)\n\n    @patch(\"graphviz.Digraph\")\n    def test_save_dag_pic_no_nodes(self, mock_digraph):\n        \"\"\"Test save_dag_pic with no nodes in the graph.\"\"\"\n        yaml_content = \"\"\"\n        dag_id: \"test_empty_graph_pic\"\n        nodes: []\n        \"\"\"\n        file_path = self._create_config_file(yaml_content)\n        task_graph = self.loader.load_from_file(file_path, \"yaml\")\n        with capture_loguru_logs(level=\"WARNING\") as logs:\n            result = task_graph.save_dag_pic()\n            self.assertIn(\"DAG graph 'test_empty_graph_pic' is empty. 
No image will be generated.\", \"\\n\".join(logs))\n        self.assertIsNone(result)\n        mock_digraph.assert_not_called()\n\n    @patch(\"graphviz.Digraph\")\n    def test_save_dag_pic_invalid_graph(self, mock_digraph):\n        \"\"\"Test save_dag_pic with an invalid graph (circular dependency).\"\"\"\n        yaml_content = \"\"\"\n        dag_id: \"test_circular_dep_pic\"\n        nodes:\n          - node_id: \"nodeA\"\n            node_type: \"COMPUTE\"\n            dependencies: [\"nodeB\"]\n          - node_id: \"nodeB\"\n            node_type: \"COMPUTE\"\n            dependencies: [\"nodeA\"]\n        \"\"\"\n        file_path = self._create_config_file(yaml_content)\n        # We expect load_config to raise ValueError for circular dependency\n        with self.assertRaises(ValueError):\n            self.loader.load_from_file(file_path, \"yaml\")\n\n        # To test save_dag_pic on an invalid graph, we'd need to bypass load_config's validation\n        # or create an invalid graph manually. Let's create one manually.\n        graph = TaskGraph(\"manual_invalid_graph\")\n        nodeA = Node(\"nodeA\", NodeType.COMPUTE, dependencies=[\"nodeB\"])\n        nodeB = Node(\"nodeB\", NodeType.COMPUTE, dependencies=[\"nodeA\"])\n        graph.add_node(nodeA)\n        graph.add_node(nodeB)\n        graph.build_adjacency_lists()  # Ensure adj lists are built for validation\n\n        with capture_loguru_logs(level=\"ERROR\") as logs:\n            result = graph.save_dag_pic()\n            self.assertIn(\"Graph 'manual_invalid_graph' is invalid. Unable to generate image:\", \"\\n\".join(logs))\n        self.assertIsNone(result)\n        mock_digraph.assert_not_called()\n    \n    def test_load_profiler_yaml_config(self):\n        \"\"\"Test loading a valid profiler configuration.\"\"\"\n        yaml_context=\"\"\"\n        data: null\n        actor_rollout_ref: null\n        critic: null\n        reward_model: null\n        custom_reward_function: null\n        algorithm: null\n        trainer: null\n        dag: null\n        profiler:\n          enable: True\n          save_path: './prof_data'\n          level: 'level1'\n          with_memory: False\n          record_shapes: False\n          with_npu: True\n          with_cpu: False\n          with_module: False\n          with_stack: False\n          analysis: True\n          discrete: False\n          roles: ['generate', 'compute_reward']\n          all_ranks: False\n          ranks: [0]\n          profile_steps: [0]\n        \"\"\"\n        file_path = self._create_config_file(yaml_context, \"yaml\")\n        from omegaconf import OmegaConf\n        yaml_dict = OmegaConf.load(file_path)\n        profiler = parse_config(yaml_dict).profiler\n        self.assertIsInstance(profiler, ProfilerArguments)\n        self.assertTrue(profiler.enable)\n        self.assertEqual(profiler.level, \"level1\")\n        self.assertFalse(profiler.with_memory)\n        self.assertTrue(profiler.with_npu)\n\n        def is_subset(subset, superset):\n            set_subset = set(subset)\n            set_superset = set(superset)\n            return set_subset.issubset(set_superset)\n        self.assertTrue(is_subset(profiler.roles, [\"generate\", \"compute_reward\", \"compute_old_log_prob\", \n                        \"compute_ref_log_porb\", \"compute_value\", \"compute_advantage\", \"train_critic\", \"train_actor\"]))\n\n    def test_profiler_npu_environment(self):\n        \"\"\"Test npu environment for profiler.\"\"\"\n        config = 
ProfilerArguments(enable=True)\n        profiler = DistProfiler(rank=0, config=config)\n        from siirl.utils.extras.device import is_npu_available\n        if not is_npu_available:\n            self.assertFalse(profiler.config.enable)\n        else:\n            self.assertTrue(profiler.config.enable)\n\n\nif __name__ == \"__main__\":\n    unittest.main(argv=[\"first-arg-is-ignored\"], exit=False)\n"
  },
  {
    "path": "tests/dag/test_node.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport asyncio\nimport unittest\nfrom typing import Any, Dict\n\nfrom siirl.workers.dag import Node, NodeRole, NodeStatus, NodeType\n\n\n# Helper function for testing executable_ref (defined inside the test module)\ndef local_sync_executable(data: Any, node_config: Dict = None) -> Dict:\n    \"\"\"A simple local synchronous executable function for testing.\"\"\"\n    # print(f\"local_sync_executable called with data: {data}, config: {node_config}\")\n    return {\"processed_data\": data, \"config_used\": node_config}\n\n\nasync def local_async_executable(data: Any, node_config: Dict = None) -> Dict:\n    \"\"\"A simple local asynchronous executable function for testing.\"\"\"\n    # print(f\"local_async_executable called with data: {data}, config: {node_config}\")\n    await asyncio.sleep(0.01)  # Simulate an IO operation\n    return {\"processed_async_data\": data, \"config_used\": node_config}\n\n\n\"\"\"A non-callable local variable for testing error handling.\"\"\"\nlocal_not_callable = \"I am not a function, I am a result.\"\n\n\n# Create a function that always fails for testing retries\nasync def always_fail_func(**kwargs):\n    print(f\"  [Executable] always_fail_func: Executing and about to fail...\")\n    await asyncio.sleep(0.05)\n    raise RuntimeError(\"Simulated execution error\")\n\n\nclass TestNode(unittest.TestCase):\n    \"\"\"Unit tests for the Node class\"\"\"\n\n    def test_node_creation_valid(self):\n        \"\"\"Test the creation of a Node with valid parameters.\"\"\"\n        node = Node(node_id=\"n1\", node_type=NodeType.COMPUTE, node_role=NodeRole.DEFAULT, dependencies=[\"dep1\"], config={\"key\": \"value\"}, executable_ref=f\"{__name__}.local_sync_executable\", retry_limit=3)\n        self.assertEqual(node.node_id, \"n1\")\n        self.assertEqual(node.node_type, NodeType.COMPUTE)\n        self.assertEqual(node.dependencies, [\"dep1\"])\n        self.assertEqual(node.config, {\"key\": \"value\"})\n        self.assertTrue(callable(node.executable))\n        self.assertEqual(node.node_role, NodeRole.DEFAULT)\n        self.assertEqual(node.retry_limit, 3)\n        self.assertEqual(node.status, NodeStatus.PENDING)\n\n    def test_node_creation_minimal(self):\n        \"\"\"Test the creation of a Node with minimal parameters.\"\"\"\n        node = Node(node_id=\"n2\", node_type=NodeType.DATA_LOAD)\n        self.assertEqual(node.node_id, \"n2\")\n        self.assertEqual(node.node_type, NodeType.DATA_LOAD)\n        self.assertEqual(node.node_role, NodeRole.DEFAULT)\n        self.assertEqual(node.dependencies, [])\n        self.assertEqual(node.config, {})\n        self.assertIsNone(node.executable)\n        self.assertEqual(node.retry_limit, 0)\n\n    def test_node_creation_invalid_id(self):\n        \"\"\"Test the creation of a Node with an invalid node_id (empty string).\"\"\"\n        with self.assertRaisesRegex(ValueError, 
\"node_id must be a non-empty string\"):\n            Node(node_id=\"\", node_type=NodeType.COMPUTE)\n\n    def test_node_creation_invalid_type(self):\n        \"\"\"Test the creation of a Node with an invalid node_type.\"\"\"\n        with self.assertRaisesRegex(ValueError, \"node_type must be a member of the NodeType enum\"):\n            Node(node_id=\"n3\", node_type=\"INVALID_TYPE\")  # type: ignore\n\n    def test_resolve_executable_non_existent(self):\n        \"\"\"Test resolving a non-existent executable function reference.\"\"\"\n        with self.assertRaisesRegex(ImportError, \"Failed to load the executable function from 'non_existent_module.non_existent_func'\"):\n            Node(node_id=\"n4\", node_type=NodeType.COMPUTE, executable_ref=\"non_existent_module.non_existent_func\")\n\n    def test_resolve_executable_not_callable(self):\n        \"\"\"Test resolving a non-callable executable function reference.\"\"\"\n        with self.assertRaisesRegex(ImportError, f\".*The object resolved from .* is not callable.\"):\n            Node(node_id=\"n5\", node_type=NodeType.COMPUTE, executable_ref=f\"{__name__}.local_not_callable\")\n\n    def test_add_remove_dependency(self):\n        \"\"\"Test adding and removing dependencies.\"\"\"\n        node = Node(node_id=\"n6\", node_type=NodeType.COMPUTE)\n        node.add_dependency(\"dep1\")\n        self.assertIn(\"dep1\", node.dependencies)\n        node.add_dependency(\"dep2\")\n        self.assertIn(\"dep2\", node.dependencies)\n        node.add_dependency(\"dep1\")  # Adding a duplicate dependency should have no effect\n        self.assertEqual(node.dependencies.count(\"dep1\"), 1)\n        node.remove_dependency(\"dep1\")\n        self.assertNotIn(\"dep1\", node.dependencies)\n        node.remove_dependency(\"non_existent_dep\")  # Removing a non-existent dependency should have no effect\n        self.assertEqual(node.dependencies, [\"dep2\"])\n\n    def test_is_ready(self):\n        \"\"\"Test the logic of the is_ready method.\"\"\"\n        node_no_deps = Node(node_id=\"n_no_deps\", node_type=NodeType.COMPUTE)\n        node_with_deps = Node(node_id=\"n_with_deps\", node_type=NodeType.COMPUTE, dependencies=[\"d1\", \"d2\"])\n\n        self.assertTrue(node_no_deps.is_ready(set()))\n        self.assertTrue(node_no_deps.is_ready({\"d1\"}))  # Unrelated completed nodes\n\n        self.assertFalse(node_with_deps.is_ready(set()))\n        self.assertFalse(node_with_deps.is_ready({\"d1\"}))\n        self.assertTrue(node_with_deps.is_ready({\"d1\", \"d2\"}))\n        self.assertTrue(node_with_deps.is_ready({\"d1\", \"d2\", \"d3\"}))\n\n        node_with_deps.update_status(NodeStatus.RUNNING)\n        self.assertFalse(node_with_deps.is_ready({\"d1\", \"d2\"}), \"Nodes not in PENDING status should not be ready\")\n\n    def test_update_status(self):\n        \"\"\"Test updating the node status.\"\"\"\n        node = Node(node_id=\"n7\", node_type=NodeType.COMPUTE)\n        self.assertEqual(node.status, NodeStatus.PENDING)\n        node.update_status(NodeStatus.COMPLETED)\n        self.assertEqual(node.status, NodeStatus.COMPLETED)\n        self.assertIsNone(node.error_info)\n\n        node.update_status(NodeStatus.FAILED, \"It failed\")\n        self.assertEqual(node.status, NodeStatus.FAILED)\n        self.assertEqual(node.error_info, \"It failed\")\n\n        # error_info should be cleared upon successful completion\n        node.update_status(NodeStatus.COMPLETED)\n        self.assertIsNone(node.error_info)\n\n    def 
test_retry_logic(self):\n        \"\"\"Test retry-related logic.\"\"\"\n        node = Node(node_id=\"n8\", node_type=NodeType.COMPUTE, retry_limit=2)\n        self.assertFalse(node.can_retry())  # Initial state is PENDING\n        node.update_status(NodeStatus.FAILED)\n        self.assertTrue(node.can_retry())\n        node.increment_retry_count()\n        self.assertEqual(node.retries_done, 1)\n        self.assertTrue(node.can_retry())\n        node.increment_retry_count()\n        self.assertEqual(node.retries_done, 2)\n        self.assertFalse(node.can_retry())  # Retry limit reached\n\n    def test_execute_no_executable(self):\n        \"\"\"Test executing a node without an executable function.\"\"\"\n        node = Node(node_id=\"n9\", node_type=NodeType.BARRIER_SYNC)\n        output = asyncio.run(node.run())\n        self.assertIsNone(output)\n        self.assertEqual(node.status, NodeStatus.RUNNING)  # run sets the status to RUNNING internally\n        # If there is no executable function, it will not automatically change to COMPLETED\n        # In the original code, it returns directly when there is no executable, and the status remains RUNNING\n        # The caller may need to update the status later\n        # For testing purposes, we check if it is None and the status is RUNNING\n        # The actual scheduler should handle this situation\n\n    def test_execute_sync_executable(self):\n        \"\"\"Test executing a synchronous executable function.\"\"\"\n        node_id = \"n10_sync\"\n        node_config = {\"multiplier\": 3}\n        # Note: The executable_ref here must point to a function in the test file or an imported function\n        node = Node(node_id=node_id, node_type=NodeType.COMPUTE, executable_ref=f\"{__name__}.local_sync_executable\", config=node_config)\n\n        # Node.run is an async def, so even if it calls a synchronous function internally, it still needs to be awaited\n        # kwargs simulate the output from dependent nodes\n        result_coro = node.run(data=5)  # run returns a coroutine\n        result = asyncio.run(result_coro)  # Run this coroutine\n\n        self.assertEqual(result, {\"processed_data\": 5, \"config_used\": node_config})\n        self.assertEqual(node.output, {\"processed_data\": 5, \"config_used\": node_config})\n        self.assertEqual(node.status, NodeStatus.COMPLETED)\n\n    def test_execute_async_executable(self):\n        \"\"\"Test executing an asynchronous executable function.\"\"\"\n        node_id = \"n11_async\"\n        node_config = {\"adder\": 10}\n        node = Node(node_id=node_id, node_type=NodeType.COMPUTE, executable_ref=f\"{__name__}.local_async_executable\", config=node_config)\n\n        result_coro = node.run(data=7)  # run returns a coroutine\n        result = asyncio.run(result_coro)  # Run this coroutine\n\n        self.assertEqual(result, {\"processed_async_data\": 7, \"config_used\": node_config})\n        self.assertEqual(node.output, {\"processed_async_data\": 7, \"config_used\": node_config})\n        self.assertEqual(node.status, NodeStatus.COMPLETED)\n\n    def test_execute_failing_executable(self):\n        \"\"\"Test executing a failing executable function.\"\"\"\n        node = Node(node_id=\"n12_fail\", node_type=NodeType.COMPUTE, executable_ref=f\"{__name__}.always_fail_func\")\n\n        with self.assertRaisesRegex(RuntimeError, \"An error occurred while executing node n12_fail: Simulated execution error\"):\n            asyncio.run(node.run())  # The coroutine returned by run will raise an 
exception when awaited\n\n        self.assertEqual(node.status, NodeStatus.FAILED)\n        self.assertIn(\"Simulated execution error\", node.error_info)\n\n\nif __name__ == \"__main__\":\n    # Ensure that executable_ref in the test can correctly resolve functions in this file\n    # If the test file is named test_my_module.py, __name__ will be 'test_my_module'\n    # If you run python test_my_module.py directly, __name__ will be '__main__'\n    # To ensure consistency, you can set a global variable or dynamically obtain the module name in the test\n    # But usually the unittest test loader will handle the module import correctly, so __name__ should point to the test module name\n\n    # print(f\"Running tests with __name__ = {__name__}\")\n    # print(f\"Attempting to use executable_ref like: {__name__}.local_sync_executable\")\n\n    unittest.main(verbosity=2)\n"
  },
  {
    "path": "tests/dag/test_task_graph.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport unittest\n\nfrom siirl.workers.dag import Node, NodeRole, NodeStatus, NodeType, TaskGraph\n\n\ndef example_data_load_func():\n    pass\n\n\ndef example_compute_func():\n    pass\n\n\nclass TestTaskGraph(unittest.TestCase):\n    \"\"\"Unit tests for the TaskGraph class\"\"\"\n\n    def setUp(self):\n        self.graph = TaskGraph(graph_id=\"test_graph\")\n        self.node_a = Node(node_id=\"A\", node_type=NodeType.DATA_LOAD)\n        self.node_b = Node(node_id=\"B\", node_type=NodeType.COMPUTE, node_role=NodeRole.ROLLOUT, dependencies=[\"A\"])\n        self.node_c = Node(node_id=\"C\", node_type=NodeType.COMPUTE, node_role=NodeRole.REFERENCE, dependencies=[\"A\"])\n        self.node_d = Node(node_id=\"D\", node_type=NodeType.BARRIER_SYNC, dependencies=[\"B\", \"C\"])\n        self.node_e = Node(node_id=\"E\", node_type=NodeType.MODEL_TRAIN, node_role=NodeRole.ACTOR, dependencies=[\"D\"])\n\n        self.dag_module_sync_ref = f\"{__name__}.example_data_load_func\"\n        self.dag_module_async_ref = f\"{__name__}.example_compute_func\"\n\n    def test_graph_creation(self):\n        self.assertEqual(self.graph.graph_id, \"test_graph\")\n        self.assertEqual(self.graph.nodes, {})\n\n    def test_add_node(self):\n        self.graph.add_node(self.node_a)\n        self.assertIn(\"A\", self.graph.nodes)\n        self.assertEqual(self.graph.nodes[\"A\"], self.node_a)\n        new_node_a = Node(node_id=\"A\", node_type=NodeType.CUSTOM)\n        self.graph.add_node(new_node_a)\n        self.assertEqual(self.graph.nodes[\"A\"].node_type, NodeType.CUSTOM)\n        with self.assertRaisesRegex(ValueError, \"Only Node type objects can be added to the graph\"):\n            self.graph.add_node(\"not_a_node\")  # type: ignore\n\n    def test_build_adjacency_lists(self):\n        self.graph.add_node(self.node_a)\n        self.graph.add_node(self.node_b)\n        self.graph.add_node(self.node_c)\n        self.graph.add_node(self.node_d)\n        self.graph.build_adjacency_lists()\n        self.assertEqual(set(self.graph.adj.get(\"A\", [])), {\"B\", \"C\"})\n        self.assertEqual(set(self.graph.adj.get(\"B\", [])), {\"D\"})\n        self.assertEqual(set(self.graph.adj.get(\"C\", [])), {\"D\"})\n        self.assertEqual(self.graph.adj.get(\"D\", []), [])\n        self.assertEqual(set(self.graph.rev_adj.get(\"A\", [])), set())\n        self.assertEqual(set(self.graph.rev_adj.get(\"B\", [])), {\"A\"})\n        self.assertEqual(set(self.graph.rev_adj.get(\"C\", [])), {\"A\"})\n        self.assertEqual(set(self.graph.rev_adj.get(\"D\", [])), {\"B\", \"C\"})\n        graph2 = TaskGraph(\"graph2\")\n        node_x = Node(\"X\", NodeType.COMPUTE)\n        node_y = Node(\"Y\", NodeType.COMPUTE, dependencies=[\"X\"])\n        graph2.add_node(node_x)\n        graph2.add_node(node_y)\n        self.assertEqual(set(graph2.adj.get(\"X\", [])), {\"Y\"})  # Relies on 
_update_adj_for_node\n        self.assertEqual(set(graph2.rev_adj.get(\"Y\", [])), {\"X\"})  # Relies on _update_adj_for_node\n\n    def test_get_node(self):\n        self.graph.add_node(self.node_a)\n        self.assertEqual(self.graph.get_node(\"A\"), self.node_a)\n        self.assertIsNone(self.graph.get_node(\"X\"))\n\n    def test_get_dependencies_and_dependents(self):\n        self.graph.add_node(self.node_a)\n        self.graph.add_node(self.node_b)\n        self.graph.add_node(self.node_c)\n        self.graph.build_adjacency_lists()\n        deps_b = [n.node_id for n in self.graph.get_dependencies(\"B\")]\n        self.assertEqual(deps_b, [\"A\"])\n        dependents_a = sorted([n.node_id for n in self.graph.get_dependents(\"A\")])\n        self.assertEqual(dependents_a, [\"B\", \"C\"])\n        self.assertEqual(self.graph.get_dependencies(\"A\"), [])\n        self.assertEqual(self.graph.get_dependents(\"C\"), [])\n\n    def test_get_entry_and_exit_nodes(self):\n        self.graph.add_node(self.node_a)\n        self.graph.add_node(self.node_b)\n        self.graph.add_node(self.node_c)\n        self.graph.add_node(self.node_d)\n        self.graph.add_node(self.node_e)\n        self.graph.build_adjacency_lists()\n        entry_nodes = sorted([n.node_id for n in self.graph.get_entry_nodes()])\n        self.assertEqual(entry_nodes, [\"A\"])\n        exit_nodes = sorted([n.node_id for n in self.graph.get_exit_nodes()])\n        self.assertEqual(exit_nodes, [\"E\"])\n\n    def test_validate_graph_valid(self):\n        self.graph.add_node(self.node_a)\n        self.graph.add_node(self.node_b)\n        self.graph.build_adjacency_lists()\n        is_valid, msg = self.graph.validate_graph()\n        self.assertTrue(is_valid)\n        self.assertIsNone(msg)\n\n    def test_validate_graph_missing_dependency(self):\n        node_x = Node(\"X\", NodeType.COMPUTE, dependencies=[\"Y_non_existent\"])\n        self.graph.add_node(node_x)\n        self.graph.build_adjacency_lists()\n        is_valid, msg = self.graph.validate_graph()\n        self.assertFalse(is_valid)\n        self.assertIn(\"The dependency 'Y_non_existent' of node 'X' does not exist in the graph\", msg)\n\n    def test_validate_graph_cyclic(self):\n        node_x = Node(\"X_cyclic\", NodeType.COMPUTE, dependencies=[\"Y_cyclic\"])\n        node_y = Node(\"Y_cyclic\", NodeType.COMPUTE, dependencies=[\"X_cyclic\"])\n        self.graph.add_node(node_x)\n        self.graph.add_node(node_y)\n        self.graph.build_adjacency_lists()\n        is_valid, msg = self.graph.validate_graph()\n        self.assertFalse(is_valid)\n        self.assertIn(\"There are circular dependencies in\", msg)\n\n    def test_get_topological_sort_valid(self):\n        self.graph.add_node(self.node_a)\n        self.graph.add_node(self.node_b)\n        self.graph.add_node(self.node_c)\n        self.graph.add_node(self.node_d)\n        self.graph.build_adjacency_lists()\n        order = self.graph.get_topological_sort()\n        self.assertEqual(len(order), 4)\n        self.assertEqual(set(order), {\"A\", \"B\", \"C\", \"D\"})\n        self.assertLess(order.index(\"A\"), order.index(\"B\"))\n        self.assertLess(order.index(\"A\"), order.index(\"C\"))\n        if \"B\" in order and \"D\" in order:\n            self.assertLess(order.index(\"B\"), order.index(\"D\"))\n        if \"C\" in order and \"D\" in order:\n            self.assertLess(order.index(\"C\"), order.index(\"D\"))\n\n    def test_get_topological_sort_empty_graph(self):\n        
self.assertEqual(self.graph.get_topological_sort(), [])\n\n    def test_get_topological_sort_cyclic(self):\n        node_x = Node(\"X_cyclic_topo\", NodeType.COMPUTE, dependencies=[\"Y_cyclic_topo\"])\n        node_y = Node(\"Y_cyclic_topo\", NodeType.COMPUTE, dependencies=[\"X_cyclic_topo\"])\n        self.graph.add_node(node_x)\n        self.graph.add_node(node_y)\n        self.graph.build_adjacency_lists()\n        with self.assertRaisesRegex(ValueError, \"There are circular dependencies\"):\n            self.graph.get_topological_sort()\n\n    def test_reset_nodes_status(self):\n        self.node_a.update_status(NodeStatus.COMPLETED, \"Done A\")\n        self.node_a.output = \"Output A\"\n        self.node_b.update_status(NodeStatus.FAILED, \"Failed B\")\n        self.node_b.retries_done = 1\n        self.graph.add_node(self.node_a)\n        self.graph.add_node(self.node_b)\n        self.graph.reset_nodes_status()\n        self.assertEqual(self.node_a.status, NodeStatus.PENDING)\n        self.assertIsNone(self.node_a.output)\n        self.assertIsNone(self.node_a.error_info)\n        self.assertEqual(self.node_a.retries_done, 0)\n        self.assertEqual(self.node_b.status, NodeStatus.PENDING)\n        self.assertIsNone(self.node_b.output)\n        self.assertIsNone(self.node_b.error_info)\n        self.assertEqual(self.node_b.retries_done, 0)\n\n    def test_load_from_config_valid(self):\n        graph_config = [\n            {\"node_id\": \"cfg_A\", \"node_type\": \"DATA_LOAD\", \"executable_ref\": self.dag_module_sync_ref, \"config\": {\"path\": \"dummy.csv\"}},\n            {\"node_id\": \"cfg_B\", \"node_type\": \"COMPUTE\", \"dependencies\": [\"cfg_A\"], \"executable_ref\": self.dag_module_async_ref, \"config\": {\"operation\": \"sum\"}, \"node_role\": \"ACTOR\", \"retry_limit\": 1},\n            {\"node_id\": \"cfg_C\", \"node_type\": \"BARRIER_SYNC\", \"dependencies\": [\"cfg_B\"]},\n        ]\n        graph = TaskGraph.load_from_config(\"config_graph_1\", graph_config)\n        self.assertEqual(graph.graph_id, \"config_graph_1\")\n        self.assertEqual(len(graph.nodes), 3)\n        self.assertIn(\"cfg_A\", graph.nodes)\n        self.assertEqual(graph.nodes[\"cfg_A\"].node_type, NodeType.DATA_LOAD)\n        self.assertEqual(graph.nodes[\"cfg_B\"].node_role, NodeRole.ACTOR)\n        self.assertEqual(graph.nodes[\"cfg_B\"].retry_limit, 1)\n        self.assertTrue(callable(graph.nodes[\"cfg_A\"].executable))\n        order = graph.get_topological_sort()\n        self.assertEqual(order, [\"cfg_A\", \"cfg_B\", \"cfg_C\"])\n\n    def test_load_from_config_missing_field(self):\n        graph_config_no_id = [{\"node_type\": \"DATA_LOAD\"}]\n        with self.assertRaisesRegex(ValueError, \".*missing required field: 'node_id'.*\"):\n            TaskGraph.load_from_config(\"bad_cfg_no_id\", graph_config_no_id)\n        graph_config_no_type = [{\"node_id\": \"X_no_type\"}]\n        with self.assertRaisesRegex(ValueError, \".*missing 'node_type'.*\"):\n            TaskGraph.load_from_config(\"bad_cfg_no_type\", graph_config_no_type)\n\n    def test_load_from_config_invalid_enum_value(self):\n        graph_config_invalid_type = [{\"node_id\": \"X_invalid_type\", \"node_type\": \"INVALID_NODE_TYPE_VALUE\"}]\n        with self.assertRaisesRegex(ValueError, \".*INVALID_NODE_TYPE_VALUE.*\"):\n            TaskGraph.load_from_config(\"bad_cfg_invalid_type\", graph_config_invalid_type)\n        graph_config_invalid_role = [{\"node_id\": \"Y_invalid_role\", \"node_type\": \"COMPUTE\", 
\"node_role\": \"INVALID_ROLE_VALUE\"}]\n        with self.assertRaisesRegex(ValueError, \".*INVALID_ROLE_VALUE.*\"):\n            TaskGraph.load_from_config(\"bad_cfg_invalid_role\", graph_config_invalid_role)\n\n    def test_load_from_config_invalid_graph_structure(self):\n        graph_config_cyclic = [\n            {\"node_id\": \"X_cfg_cyclic\", \"node_type\": \"COMPUTE\", \"dependencies\": [\"Y_cfg_cyclic\"]},\n            {\"node_id\": \"Y_cfg_cyclic\", \"node_type\": \"COMPUTE\", \"dependencies\": [\"X_cfg_cyclic\"]},\n        ]\n        with self.assertRaisesRegex(ValueError, \".*configuration is invalid:.*There are circular dependencies.*\"):\n            TaskGraph.load_from_config(\"cyclic_cfg_from_config\", graph_config_cyclic)\n\n\nif __name__ == \"__main__\":\n    unittest.main(verbosity=2)\n"
  },
  {
    "path": "tests/dag/test_task_loader.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport unittest\nimport copy\nfrom typing import List, Set, Dict, Tuple\n\n# Assuming node.py, task_graph.py, task_loader.py are accessible\nfrom siirl.workers.dag import Node, NodeType, TaskGraph\nfrom siirl.workers.dag.task_loader import generate_structural_signature, get_all_downstream_nodes_recursive, get_all_ancestors, find_all_paths, split_single_structure, split_by_fan_out_to_exits, split_by_reconverging_paths, discover_and_split_parallel_paths\n\n\n# Helper to get a set of structural signatures for a list of graphs\ndef get_signatures(graphs: List[TaskGraph]) -> List[str]:\n    \"\"\"Generates a sorted list of structural signatures for a list of graphs.\"\"\"\n    return sorted([generate_structural_signature(g) for g in graphs])\n\n\nclass TestTaskLoaderInternals(unittest.TestCase):\n    \"\"\"Tests for internal helper functions in task_loader.py.\"\"\"\n\n    def setUp(self):\n        # Simple graph: A -> B -> C, A -> D\n        #      A\n        #     / \\\n        #    B   D\n        #    |\n        #    C\n        self.node_a = Node(\"A\", NodeType.DATA_LOAD)\n        self.node_b = Node(\"B\", NodeType.COMPUTE, dependencies=[\"A\"])\n        self.node_c = Node(\"C\", NodeType.COMPUTE, dependencies=[\"B\"])\n        self.node_d = Node(\"D\", NodeType.COMPUTE, dependencies=[\"A\"])\n        self.graph1 = TaskGraph(\"g1\")\n        self.graph1.add_nodes([self.node_a, self.node_b, self.node_c, self.node_d])\n        self.graph1.build_adjacency_lists()  # Crucial for many helpers\n\n        # Linear graph: L1 -> L2 -> L3\n        self.ln1 = Node(\"L1\", NodeType.DATA_LOAD)\n        self.ln2 = Node(\"L2\", NodeType.COMPUTE, dependencies=[\"L1\"])\n        self.ln3 = Node(\"L3\", NodeType.MODEL_TRAIN, dependencies=[\"L2\"])\n        self.linear_graph = TaskGraph(\"linear\")\n        self.linear_graph.add_nodes([self.ln1, self.ln2, self.ln3])\n        self.linear_graph.build_adjacency_lists()\n\n        # Empty graph\n        self.empty_graph = TaskGraph(\"empty\")\n\n    def test_generate_structural_signature(self):\n        sig1 = generate_structural_signature(self.graph1)\n\n        graph1_reordered_nodes = TaskGraph(\"g1_reordered\")\n        # Add nodes in different order but same structure\n        graph1_reordered_nodes.add_node(copy.deepcopy(self.node_d))\n        graph1_reordered_nodes.add_node(copy.deepcopy(self.node_c))\n        graph1_reordered_nodes.add_node(copy.deepcopy(self.node_b))\n        graph1_reordered_nodes.add_node(copy.deepcopy(self.node_a))\n        sig2 = generate_structural_signature(graph1_reordered_nodes)\n        self.assertEqual(sig1, sig2, \"Signatures should match for structurally identical graphs regardless of node add order.\")\n\n        sig_empty = generate_structural_signature(self.empty_graph)\n        self.assertTrue(\"empty_structure\" in sig_empty)\n\n        sig_linear = 
generate_structural_signature(self.linear_graph)\n        self.assertTrue(\"L1\" in sig_linear and \"L2\" in sig_linear and \"L3\" in sig_linear)\n        self.assertTrue(\"e(L1->L2)\" in sig_linear)\n\n    def test_get_all_downstream_nodes_recursive(self):\n        # For self.graph1 (A -> B -> C, A -> D)\n        # Downstream of A should be A, B, C, D\n        downstream_a = get_all_downstream_nodes_recursive(self.graph1, \"A\")\n        self.assertSetEqual(downstream_a, {\"A\", \"B\", \"C\", \"D\"})\n\n        # Downstream of B should be B, C\n        downstream_b = get_all_downstream_nodes_recursive(self.graph1, \"B\")\n        self.assertSetEqual(downstream_b, {\"B\", \"C\"})\n\n        # Downstream of C should be C\n        downstream_c = get_all_downstream_nodes_recursive(self.graph1, \"C\")\n        self.assertSetEqual(downstream_c, {\"C\"})\n\n        # Downstream of D should be D\n        downstream_d = get_all_downstream_nodes_recursive(self.graph1, \"D\")\n        self.assertSetEqual(downstream_d, {\"D\"})\n\n        # Non-existent node\n        downstream_non_existent = get_all_downstream_nodes_recursive(self.graph1, \"Z\")\n        self.assertSetEqual(downstream_non_existent, set())\n\n        # Empty graph\n        downstream_empty = get_all_downstream_nodes_recursive(self.empty_graph, \"A\")\n        self.assertSetEqual(downstream_empty, set())\n\n    def test_get_all_ancestors(self):\n        # For self.graph1 (A -> B -> C, A -> D)\n        ancestors_c = get_all_ancestors(self.graph1, \"C\")  # Ancestors of C are B, A\n        self.assertSetEqual(ancestors_c, {\"A\", \"B\"})\n\n        ancestors_b = get_all_ancestors(self.graph1, \"B\")  # Ancestors of B is A\n        self.assertSetEqual(ancestors_b, {\"A\"})\n\n        ancestors_d = get_all_ancestors(self.graph1, \"D\")  # Ancestors of D is A\n        self.assertSetEqual(ancestors_d, {\"A\"})\n\n        ancestors_a = get_all_ancestors(self.graph1, \"A\")  # A has no ancestors\n        self.assertSetEqual(ancestors_a, set())\n\n        # Non-existent node\n        ancestors_non_existent = get_all_ancestors(self.graph1, \"Z\")\n        self.assertSetEqual(ancestors_non_existent, set())\n\n        # Empty graph\n        ancestors_empty = get_all_ancestors(self.empty_graph, \"A\")\n        self.assertSetEqual(ancestors_empty, set())\n\n    def test_find_all_paths(self):\n        # For self.graph1 (A -> B -> C, A -> D)\n        paths_a_c = find_all_paths(self.graph1, \"A\", \"C\")\n        self.assertListEqual(paths_a_c, [[\"A\", \"B\", \"C\"]])\n\n        paths_a_d = find_all_paths(self.graph1, \"A\", \"D\")\n        self.assertListEqual(paths_a_d, [[\"A\", \"D\"]])\n\n        paths_a_b = find_all_paths(self.graph1, \"A\", \"B\")\n        self.assertListEqual(paths_a_b, [[\"A\", \"B\"]])\n\n        paths_b_d = find_all_paths(self.graph1, \"B\", \"D\")  # No path\n        self.assertListEqual(paths_b_d, [])\n\n        paths_a_a = find_all_paths(self.graph1, \"A\", \"A\")  # Path to self\n        self.assertListEqual(paths_a_a, [[\"A\"]])\n\n        # Diamond graph: S -> T1 -> E, S -> T2 -> E\n        s, t1, t2, e = (Node(\"S\", NodeType.COMPUTE), Node(\"T1\", NodeType.COMPUTE, dependencies=[\"S\"]), Node(\"T2\", NodeType.COMPUTE, dependencies=[\"S\"]), Node(\"E\", NodeType.COMPUTE, dependencies=[\"T1\", \"T2\"]))\n        diamond_graph = TaskGraph(\"diamond\")\n        diamond_graph.add_nodes([s, t1, t2, e])\n        paths_s_e = find_all_paths(diamond_graph, \"S\", \"E\")\n        self.assertEqual(len(paths_s_e), 2)\n        
self.assertIn([\"S\", \"T1\", \"E\"], paths_s_e)\n        self.assertIn([\"S\", \"T2\", \"E\"], paths_s_e)\n\n        # Non-existent start/end\n        self.assertListEqual(find_all_paths(self.graph1, \"Z\", \"C\"), [])\n        self.assertListEqual(find_all_paths(self.graph1, \"A\", \"Z\"), [])\n\n\nclass TestTaskSplittingLogic(unittest.TestCase):\n    \"\"\"Tests for individual splitting strategy functions.\"\"\"\n\n    def assertGraphStructure(self, graph: TaskGraph, expected_node_ids: Set[str], expected_dependencies: Dict[str, List[str]], msg_prefix=\"\"):\n        self.assertSetEqual(set(graph.nodes.keys()), expected_node_ids, f\"{msg_prefix} Node ID mismatch\")\n        graph.build_adjacency_lists()  # Ensure rev_adj is up-to-date\n        for node_id, deps in expected_dependencies.items():\n            self.assertIn(node_id, graph.rev_adj, f\"{msg_prefix} Node {node_id} not in rev_adj\")\n            self.assertListEqual(sorted(graph.rev_adj[node_id]), sorted(deps), f\"{msg_prefix} Dependencies for {node_id} mismatch\")\n\n    def test_split_single_structure_reconvergence(self):\n        # Graph: A->B, F->C, then B,C -> D -> E (reconverge at D)\n        a = Node(\"A\", NodeType.DATA_LOAD)\n        b = Node(\"B\", NodeType.COMPUTE, dependencies=[\"A\"])\n        f = Node(\"F\", NodeType.DATA_LOAD)\n        c = Node(\"C\", NodeType.COMPUTE, dependencies=[\"F\"])\n        d = Node(\"D\", NodeType.COMPUTE, dependencies=[\"B\", \"C\"])\n        e_node = Node(\"E\", NodeType.MODEL_TRAIN, dependencies=[\"D\"])\n\n        src_graph = TaskGraph(\"reconverge_src\")\n        src_graph.add_nodes([a, b, f, c, d, e_node])\n\n        parallel_branches = [[\"A\", \"B\"], [\"F\", \"C\"]]  # Node lists for branches up to merge point\n        merge_node_id = \"D\"\n\n        subgraphes = split_single_structure(src_graph, parallel_branches, merge_node_id, \"base_idx\")\n        self.assertEqual(len(subgraphes), 2)\n\n        # Expected subgraph 1: A->B->D->E\n        # Expected subgraph 2: F->C->D->E\n        subgraph1 = TaskGraph(\"exp1\")\n        subgraph1.add_nodes(\n            [\n                copy.deepcopy(a),\n                copy.deepcopy(b),\n                Node(\"D\", NodeType.COMPUTE, dependencies=[\"B\"]),  # D's deps adjusted\n                Node(\"E\", NodeType.MODEL_TRAIN, dependencies=[\"D\"]),\n            ]\n        )\n        subgraph2 = TaskGraph(\"exp2\")\n        subgraph2.add_nodes(\n            [\n                copy.deepcopy(f),\n                copy.deepcopy(c),\n                Node(\"D\", NodeType.COMPUTE, dependencies=[\"C\"]),  # D's deps adjusted\n                Node(\"E\", NodeType.MODEL_TRAIN, dependencies=[\"D\"]),\n            ]\n        )\n        expected_sig_1 = generate_structural_signature(subgraph1)\n        expected_sig_2 = generate_structural_signature(subgraph2)\n\n        actual_sigs = get_signatures(subgraphes)\n        self.assertIn(expected_sig_1, actual_sigs)\n        self.assertIn(expected_sig_2, actual_sigs)\n\n    def test_split_by_fan_out_to_exits(self):\n        # Graph: S -> A -> E1 (exit1)\n        #        S -> B -> E2 (exit2)\n        s = Node(\"S\", NodeType.DATA_LOAD)\n        a = Node(\"A\", NodeType.COMPUTE, dependencies=[\"S\"])\n        e1 = Node(\"E1\", NodeType.MODEL_TRAIN, dependencies=[\"A\"])  # Exit 1\n        b = Node(\"B\", NodeType.COMPUTE, dependencies=[\"S\"])\n        e2 = Node(\"E2\", NodeType.MODEL_TRAIN, dependencies=[\"B\"])  # Exit 2\n\n        src_graph = TaskGraph(\"fanout_src\")\n        src_graph.add_nodes([s, 
a, e1, b, e2])\n\n        subgraphs = split_by_fan_out_to_exits(src_graph, 1)\n        self.assertEqual(len(subgraphs), 2, \"Should split into two subgraphs for distinct exits.\")\n\n        # Expected subgraph 1: S->A->E1\n        # Expected subgraph 2: S->B->E2\n        # Need to reconstruct expected TaskGraphs to get their signatures\n\n        # Graph S->A->E1\n        exp_g1 = TaskGraph(\"exp_g1\")\n        exp_g1.add_nodes([copy.deepcopy(s), Node(\"A\", NodeType.COMPUTE, dependencies=[\"S\"]), Node(\"E1\", NodeType.MODEL_TRAIN, dependencies=[\"A\"])])\n        exp_g1_sig = generate_structural_signature(exp_g1)\n\n        # Graph S->B->E2\n        exp_g2 = TaskGraph(\"exp_g2\")\n        exp_g2.add_nodes([copy.deepcopy(s), Node(\"B\", NodeType.COMPUTE, dependencies=[\"S\"]), Node(\"E2\", NodeType.MODEL_TRAIN, dependencies=[\"B\"])])\n        exp_g2_sig = generate_structural_signature(exp_g2)\n\n        actual_sigs = get_signatures(subgraphs)\n        self.assertIn(exp_g1_sig, actual_sigs)\n        self.assertIn(exp_g2_sig, actual_sigs)\n\n    def test_split_by_reconverging_paths_diamond(self):\n        # Diamond: A -> B \\\n        #               -> D\n        #          A -> C /\n        a_ = Node(\"A\", NodeType.DATA_LOAD)\n        b_ = Node(\"B\", NodeType.COMPUTE, dependencies=[\"A\"])\n        c_ = Node(\"C\", NodeType.COMPUTE, dependencies=[\"A\"])\n        d_ = Node(\"D\", NodeType.COMPUTE, dependencies=[\"B\", \"C\"])\n        src_graph = TaskGraph(\"diamond_src\")\n        src_graph.add_nodes([a_, b_, c_, d_])\n\n        subgraphs = split_by_reconverging_paths(src_graph, 1)\n        self.assertEqual(len(subgraphs), 2, \"Diamond graph should split into two paths.\")\n\n        # Expected: A->B->D and A->C->D\n        g_abd = TaskGraph(\"exp_abd\")\n        g_abd.add_nodes([copy.deepcopy(a_), Node(\"B\", NodeType.COMPUTE, dependencies=[\"A\"]), Node(\"D\", NodeType.COMPUTE, dependencies=[\"B\"])])\n        g_acd = TaskGraph(\"exp_acd\")\n        g_acd.add_nodes([copy.deepcopy(a_), Node(\"C\", NodeType.COMPUTE, dependencies=[\"A\"]), Node(\"D\", NodeType.COMPUTE, dependencies=[\"C\"])])\n\n        actual_sigs = get_signatures(subgraphs)\n        self.assertIn(generate_structural_signature(g_abd), actual_sigs)\n        self.assertIn(generate_structural_signature(g_acd), actual_sigs)\n\n\n# Helper to extend TaskGraph for easier test setup\ndef add_nodes_and_build(self, nodes: List[Node]) -> TaskGraph:\n    self.add_nodes(nodes)\n    self.build_adjacency_lists()\n    return self\n\n\nTaskGraph.add_nodes_and_build = add_nodes_and_build\n\n\nclass TestDiscoverAndSplitParallelPaths(unittest.TestCase):\n    \"\"\"Comprehensive tests for the main discover_and_split_parallel_paths function.\"\"\"\n\n    def assertListOfGraphStructuresEqual(self, actual_graphs: List[TaskGraph], expected_graph_defs: List[Tuple[str, List[Node]]], msg=None):\n        \"\"\"\n        Compares a list of actual TaskGraphs with a list of expected graph definitions.\n        A graph definition is (name_for_debug, list_of_nodes_for_expected_graph).\n        \"\"\"\n        self.assertEqual(len(actual_graphs), len(expected_graph_defs), f\"{msg}: Different number of graphs. 
Got {len(actual_graphs)}, expected {len(expected_graph_defs)}\")\n\n        expected_signatures = []\n        for name, nodes in expected_graph_defs:\n            g = TaskGraph(name)\n            g.add_nodes(nodes)\n            expected_signatures.append(generate_structural_signature(g))\n\n        actual_signatures = get_signatures(actual_graphs)\n        self.assertListEqual(sorted(actual_signatures), sorted(expected_signatures), f\"{msg}: Graph structures differ.\")\n\n    def test_empty_graph(self):\n        empty_g = TaskGraph(\"empty\")\n        split_graphs = discover_and_split_parallel_paths(empty_g)\n        self.assertEqual(len(split_graphs), 0, \"Empty graph should result in empty list.\")\n\n    def test_linear_graph(self):\n        # L1 -> L2 -> L3\n        l1, l2, l3 = (Node(\"L1\", NodeType.DATA_LOAD), Node(\"L2\", NodeType.COMPUTE, dependencies=[\"L1\"]), Node(\"L3\", NodeType.MODEL_TRAIN, dependencies=[\"L2\"]))\n        linear_g = TaskGraph(\"linear\")\n        linear_g.add_nodes([l1, l2, l3])\n\n        split_graphs = discover_and_split_parallel_paths(linear_g)\n\n        self.assertListOfGraphStructuresEqual(\n            split_graphs,\n            [(\"linear_expected\", [copy.deepcopy(n) for n in [l1, l2, l3]])],  # Deepcopy nodes for expected structure\n        )\n\n    def test_simple_reconvergence_diamond_graph(self):\n        #   A\n        #  / \\\n        # B   C\n        #  \\ /\n        #   D\n        a, b, c, d = (Node(\"A\", NodeType.DATA_LOAD), Node(\"B\", NodeType.COMPUTE, dependencies=[\"A\"]), Node(\"C\", NodeType.COMPUTE, dependencies=[\"A\"]), Node(\"D\", NodeType.COMPUTE, dependencies=[\"B\", \"C\"]))\n        diamond_g = TaskGraph(\"diamond\")\n        diamond_g.add_nodes([a, b, c, d])\n        split_graphs = discover_and_split_parallel_paths(diamond_g)\n\n        # Expected: (A->B->D) and (A->C->D)\n        exp_nodes1 = [Node(\"A\", NodeType.DATA_LOAD), Node(\"B\", NodeType.COMPUTE, dependencies=[\"A\"]), Node(\"D\", NodeType.COMPUTE, dependencies=[\"B\"])]\n        exp_nodes2 = [Node(\"A\", NodeType.DATA_LOAD), Node(\"C\", NodeType.COMPUTE, dependencies=[\"A\"]), Node(\"D\", NodeType.COMPUTE, dependencies=[\"C\"])]\n\n        self.assertListOfGraphStructuresEqual(split_graphs, [(\"path_abd\", exp_nodes1), (\"path_acd\", exp_nodes2)])\n\n    def test_fan_out_only_graph(self):\n        # A -> B (exit1)\n        #   -> C (exit2)\n        s_a, s_b_exit1, s_c_exit2 = Node(\"S_A\", NodeType.DATA_LOAD), Node(\"S_B_exit1\", NodeType.COMPUTE, dependencies=[\"S_A\"]), Node(\"S_C_exit2\", NodeType.COMPUTE, dependencies=[\"S_A\"])\n        fanout_g = TaskGraph(\"fanout_only\")\n        fanout_g.add_nodes([s_a, s_b_exit1, s_c_exit2])\n        split_graphs = discover_and_split_parallel_paths(fanout_g)\n\n        exp_nodes1 = [Node(\"S_A\", NodeType.DATA_LOAD), Node(\"S_B_exit1\", NodeType.COMPUTE, dependencies=[\"S_A\"])]\n        exp_nodes2 = [Node(\"S_A\", NodeType.DATA_LOAD), Node(\"S_C_exit2\", NodeType.COMPUTE, dependencies=[\"S_A\"])]\n\n        self.assertListOfGraphStructuresEqual(split_graphs, [(\"path_sa_sb\", exp_nodes1), (\"path_sa_sc\", exp_nodes2)])\n\n    def test_ex1_reconverge_from_prompt(self):\n        # A -> B \\\n        #         -> C -> D_ex1 -> E_ex1\n        # A1-> B1/\n        node_a_orig = Node(node_id=\"A\", node_type=NodeType.DATA_LOAD)\n        node_b_orig = Node(node_id=\"B\", node_type=NodeType.COMPUTE, dependencies=[\"A\"])\n        node_a1_orig = Node(node_id=\"A1\", node_type=NodeType.DATA_LOAD)\n        node_b1_orig = 
Node(node_id=\"B1\", node_type=NodeType.COMPUTE, dependencies=[\"A1\"])\n        node_c_orig = Node(node_id=\"C\", node_type=NodeType.COMPUTE, dependencies=[\"B\", \"B1\"])\n        node_d_ex1_orig = Node(node_id=\"D_ex1\", node_type=NodeType.COMPUTE, dependencies=[\"C\"])\n        node_e_ex1_orig = Node(node_id=\"E_ex1\", node_type=NodeType.MODEL_TRAIN, dependencies=[\"D_ex1\"])\n\n        original_graph_ex1 = TaskGraph(graph_id=\"ex1_reconverge\")\n        original_graph_ex1.add_nodes([node_a_orig, node_b_orig, node_a1_orig, node_b1_orig, node_c_orig, node_d_ex1_orig, node_e_ex1_orig])\n\n        split_graphs = discover_and_split_parallel_paths(original_graph_ex1)\n\n        # Expected path 1: A -> B -> C -> D_ex1 -> E_ex1\n        exp1_nodes = [Node(\"A\", NodeType.DATA_LOAD), Node(\"B\", NodeType.COMPUTE, dependencies=[\"A\"]), Node(\"C\", NodeType.COMPUTE, dependencies=[\"B\"]), Node(\"D_ex1\", NodeType.COMPUTE, dependencies=[\"C\"]), Node(\"E_ex1\", NodeType.MODEL_TRAIN, dependencies=[\"D_ex1\"])]\n        # Expected path 2: A1 -> B1 -> C -> D_ex1 -> E_ex1\n        exp2_nodes = [Node(\"A1\", NodeType.DATA_LOAD), Node(\"B1\", NodeType.COMPUTE, dependencies=[\"A1\"]), Node(\"C\", NodeType.COMPUTE, dependencies=[\"B1\"]), Node(\"D_ex1\", NodeType.COMPUTE, dependencies=[\"C\"]), Node(\"E_ex1\", NodeType.MODEL_TRAIN, dependencies=[\"D_ex1\"])]\n        self.assertListOfGraphStructuresEqual(split_graphs, [(\"path1_ex1\", exp1_nodes), (\"path2_ex1\", exp2_nodes)], msg=\"TestEx1Reconverge\")\n\n    def test_ex2_complex_from_prompt(self):\n        #      X -> P1 \\\n        #               -> M1 -> Z -> J1 -> K1 (exit1)\n        #      Y -> P2 /        |\n        #                       -> J2 -> K2 (exit2)\n        #      P3 ---------------^ (P3 connects to Z)\n        nx = Node(\"X\", NodeType.DATA_LOAD)\n        ny = Node(\"Y\", NodeType.DATA_LOAD)\n        np1 = Node(\"P1\", NodeType.COMPUTE, dependencies=[\"X\"])\n        np2 = Node(\"P2\", NodeType.COMPUTE, dependencies=[\"Y\"])\n        nm1 = Node(\"M1\", NodeType.COMPUTE, dependencies=[\"P1\", \"P2\"])\n        np3 = Node(\"P3\", NodeType.DATA_LOAD)\n        nz = Node(\"Z\", NodeType.COMPUTE, dependencies=[\"M1\", \"P3\"])\n        nj1 = Node(\"J1\", NodeType.COMPUTE, dependencies=[\"Z\"])\n        nj2 = Node(\"J2\", NodeType.COMPUTE, dependencies=[\"Z\"])\n        nk1 = Node(\"K1\", NodeType.MODEL_TRAIN, dependencies=[\"J1\"])  # Exit1\n        nk2 = Node(\"K2\", NodeType.MODEL_TRAIN, dependencies=[\"J2\"])  # Exit2\n\n        complex_graph = TaskGraph(\"ex2_complex\")\n        complex_graph.add_nodes([nx, ny, np1, np2, nm1, np3, nz, nj1, nj2, nk1, nk2])\n        split_graphs = discover_and_split_parallel_paths(complex_graph)\n\n        # Expected subgraphs (6 of them after full decomposition):\n        # Path Group 1 (to K1):\n        # 1. X->P1->M1->Z->J1->K1\n        exp1_k1_nodes = [\n            Node(\"X\", NodeType.DATA_LOAD),\n            Node(\"P1\", NodeType.COMPUTE, dependencies=[\"X\"]),\n            Node(\"M1\", NodeType.COMPUTE, dependencies=[\"P1\"]),\n            Node(\"Z\", NodeType.COMPUTE, dependencies=[\"M1\"]),\n            Node(\"J1\", NodeType.COMPUTE, dependencies=[\"Z\"]),\n            Node(\"K1\", NodeType.MODEL_TRAIN, dependencies=[\"J1\"]),\n        ]\n        # 2. 
Y->P2->M1->Z->J1->K1\n        exp2_k1_nodes = [\n            Node(\"Y\", NodeType.DATA_LOAD),\n            Node(\"P2\", NodeType.COMPUTE, dependencies=[\"Y\"]),\n            Node(\"M1\", NodeType.COMPUTE, dependencies=[\"P2\"]),\n            Node(\"Z\", NodeType.COMPUTE, dependencies=[\"M1\"]),\n            Node(\"J1\", NodeType.COMPUTE, dependencies=[\"Z\"]),\n            Node(\"K1\", NodeType.MODEL_TRAIN, dependencies=[\"J1\"]),\n        ]\n        # 3. P3->Z->J1->K1\n        exp3_k1_nodes = [Node(\"P3\", NodeType.DATA_LOAD), Node(\"Z\", NodeType.COMPUTE, dependencies=[\"P3\"]), Node(\"J1\", NodeType.COMPUTE, dependencies=[\"Z\"]), Node(\"K1\", NodeType.MODEL_TRAIN, dependencies=[\"J1\"])]\n\n        # Path Group 2 (to K2):\n        # 4. X->P1->M1->Z->J2->K2\n        exp1_k2_nodes = [\n            Node(\"X\", NodeType.DATA_LOAD),\n            Node(\"P1\", NodeType.COMPUTE, dependencies=[\"X\"]),\n            Node(\"M1\", NodeType.COMPUTE, dependencies=[\"P1\"]),\n            Node(\"Z\", NodeType.COMPUTE, dependencies=[\"M1\"]),\n            Node(\"J2\", NodeType.COMPUTE, dependencies=[\"Z\"]),\n            Node(\"K2\", NodeType.MODEL_TRAIN, dependencies=[\"J2\"]),\n        ]\n        # 5. Y->P2->M1->Z->J2->K2\n        exp2_k2_nodes = [\n            Node(\"Y\", NodeType.DATA_LOAD),\n            Node(\"P2\", NodeType.COMPUTE, dependencies=[\"Y\"]),\n            Node(\"M1\", NodeType.COMPUTE, dependencies=[\"P2\"]),\n            Node(\"Z\", NodeType.COMPUTE, dependencies=[\"M1\"]),\n            Node(\"J2\", NodeType.COMPUTE, dependencies=[\"Z\"]),\n            Node(\"K2\", NodeType.MODEL_TRAIN, dependencies=[\"J2\"]),\n        ]\n        # 6. P3->Z->J2->K2\n        exp3_k2_nodes = [Node(\"P3\", NodeType.DATA_LOAD), Node(\"Z\", NodeType.COMPUTE, dependencies=[\"P3\"]), Node(\"J2\", NodeType.COMPUTE, dependencies=[\"Z\"]), Node(\"K2\", NodeType.MODEL_TRAIN, dependencies=[\"J2\"])]\n\n        self.assertListOfGraphStructuresEqual(split_graphs, [(\"xp1m1zj1k1\", exp1_k1_nodes), (\"yp2m1zj1k1\", exp2_k1_nodes), (\"p3zj1k1\", exp3_k1_nodes), (\"xp1m1zj2k2\", exp1_k2_nodes), (\"yp2m1zj2k2\", exp2_k2_nodes), (\"p3zj2k2\", exp3_k2_nodes)], msg=\"TestEx2Complex\")\n\n\nif __name__ == \"__main__\":\n    unittest.main(argv=[\"first-arg-is-ignored\"], exit=False, verbosity=2)\n"
  },
  {
    "path": "tests/dag_worker/test_dag_worker.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport unittest\nimport asyncio\nimport ray\nimport torch\nimport torch.distributed as dist\nfrom tensordict import TensorDict\nfrom typing import List, Optional, Dict\n\nfrom siirl.workers.databuffer.data_buffer import DataBuffer\nfrom tests.data_buffer.test_data_buffer import compare_dataprotos\n\n\n# Mock Tokenizer for testing purposes\nclass MockTokenizer:\n    def __init__(self, pad_token_id=0):\n        self.pad_token_id = pad_token_id\n        self.padding_side = \"right\"\n\n\n@ray.remote\nclass MockDAGWorker:\n    \"\"\"\n    A mock DAGWorker that only contains the logic needed to test put/get.\n    It does not load a model or execute an actual computation graph.\n    \"\"\"\n\n    def __init__(self, rank: int, world_size: int, data_buffers: List[ray.actor.ActorHandle]):\n        self._rank = rank\n        self._world_size = world_size\n        self.data_buffers = data_buffers\n        self.tokenizer = MockTokenizer()\n\n        # Mock the distributed setup for the actor\n        dist.init_process_group(\n            backend=\"gloo\",  # Use gloo for CPU-based testing\n            init_method=f\"tcp://127.0.0.1:29501\",\n            world_size=self._world_size,\n            rank=self._rank,\n        )\n\n    async def put_data_to_buffers(self, key: str, data: TensorDict, source_dp_size: int, dest_dp_size: int):\n        # This is a copy of the logic from the actual DAGWorker\n        data.meta_info[\"padding_values\"] = {\n            \"input_ids\": self.tokenizer.pad_token_id,\n            \"responses\": self.tokenizer.pad_token_id,\n            \"labels\": -100,\n            \"attention_mask\": 0,\n            \"response_mask\": 0,\n        }\n        data.meta_info[\"padding_side\"] = self.tokenizer.padding_side\n\n        if source_dp_size == dest_dp_size:\n            obj_ref = ray.put(data)\n            await self.data_buffers[0].put.remote(key, [obj_ref])\n        else:\n            num_physical_buffers = len(self.data_buffers)\n            chunks = data.chunk(chunks=num_physical_buffers)\n            put_futures = [buffer.put.remote(key, chunk) for buffer, chunk in zip(self.data_buffers, chunks)]\n            await asyncio.gather(*put_futures)\n\n    async def get_data_from_buffers(self, key: str, my_current_dp_rank: int, my_current_dp_world_size: int) -> Optional[TensorDict]:\n        # This is a copy of the logic from the actual DAGWorker\n        if not self.data_buffers:\n            return None\n\n        first_item = await self.data_buffers[0].get.remote(key, my_current_dp_rank, my_current_dp_world_size)\n\n        if first_item is None:\n            return None\n\n        if isinstance(first_item, ray.ObjectRef):\n            return await first_item\n\n        elif isinstance(first_item, TensorDict):\n            my_sub_chunks = [first_item]\n            get_futures = [buffer.get.remote(key, my_current_dp_rank, 
my_current_dp_world_size) for buffer in self.data_buffers[1:]]\n            other_sub_chunks = await asyncio.gather(*get_futures)\n\n            for sub_chunk in other_sub_chunks:\n                if not isinstance(sub_chunk, TensorDict):\n                    return None\n                my_sub_chunks.append(sub_chunk)\n\n            return TensorDict.concat(my_sub_chunks)\n\n        return None\n\n    def get_rank(self):\n        return self._rank\n\n    def barrier(self):\n        dist.barrier()\n\n\ndef _create_test_dp(batch_size: int, seq_len: int, meta: Optional[Dict] = None) -> TensorDict:\n    \"\"\"Creates a sample TensorDict for testing.\"\"\"\n    tensors = {\n        \"input_ids\": torch.randint(1, 100, (batch_size, seq_len)),\n        \"attention_mask\": torch.ones(batch_size, seq_len, dtype=torch.long),\n    }\n    td = TensorDict(tensors, batch_size=[batch_size])\n    return TensorDict(batch=td, meta_info=meta or {})\n\n\nclass TestDAGWorkerDataFlow(unittest.IsolatedAsyncioTestCase):\n    @classmethod\n    def setUpClass(cls):\n        if not ray.is_initialized():\n            ray.init(num_cpus=8, ignore_reinit_error=True, logging_level=\"error\")\n\n    @classmethod\n    def tearDownClass(cls):\n        if ray.is_initialized():\n            ray.shutdown()\n\n    async def test_put_get_flow_sharded(self):\n        \"\"\"\n        Tests the data flow for sharded storage (source_dp != dest_dp).\n        - A single worker (rank 0) puts a sharded TensorDict.\n        - All workers get their corresponding data and reconstruct it.\n        \"\"\"\n        num_buffers = 2\n        num_workers = 4\n        source_dp, dest_dp = 2, 4\n        key = \"sharded_key\"\n\n        # 1. Setup Actors\n        buffers = [DataBuffer.remote(buffer_id=i) for i in range(num_buffers)]\n        workers = [MockDAGWorker.remote(rank=i, world_size=num_workers, data_buffers=buffers) for i in range(num_workers)]\n        await asyncio.sleep(1)  # Wait for actors to initialize dist group\n\n        # 2. Prepare Data\n        # Total batch size should be divisible by dest_dp and num_buffers\n        total_batch_size = dest_dp * num_buffers\n        full_dp = _create_test_dp(total_batch_size, seq_len=10)\n\n        # 3. Leader (worker 0) puts data\n        # Data is sharded across the 2 buffers\n        await workers[0].put_data_to_buffers.remote(key, full_dp, source_dp_size=source_dp, dest_dp_size=dest_dp)\n\n        # 4. All workers synchronize\n        await asyncio.gather(*[w.barrier.remote() for w in workers])\n\n        # 5. All workers get their data slice\n        get_futures = [w.get_data_from_buffers.remote(key, my_current_dp_rank=await w.get_rank.remote(), my_current_dp_world_size=dest_dp) for w in workers]\n        results = await asyncio.gather(*get_futures)\n\n        # 6. Verify results\n        # In sharded mode, data is distributed in an interleaved manner. 
We must manually reconstruct\n        # the expected shard for each worker.\n        buffer_chunks = full_dp.chunk(chunks=num_buffers)\n        expected_shards_by_worker = []\n        for worker_rank in range(dest_dp):\n            sub_chunks_for_worker = []\n            for buffer_chunk in buffer_chunks:\n                # The chunk from each buffer will be re-sharded for the target worker.\n                sub_sub_chunks = buffer_chunk.chunk(chunks=dest_dp)\n                sub_chunks_for_worker.append(sub_sub_chunks[worker_rank])\n            # The worker's final data is the concatenation of the sub-shards it receives from all buffers.\n            expected_shards_by_worker.append(TensorDict.concat(sub_chunks_for_worker))\n\n        self.assertEqual(len(results), len(expected_shards_by_worker))\n\n        for i, result_dp in enumerate(results):\n            self.assertIsNotNone(result_dp, f\"Worker {i} received None\")\n            expected_dp = expected_shards_by_worker[i]\n            self.assertTrue(compare_dataprotos(expected_dp, result_dp, check_meta=False), f\"TensorDict for worker {i} does not match the expected interleaved shard. Expected size {len(expected_dp)}, got {len(result_dp)}\")\n\n        # Cleanup\n        for actor in workers + buffers:\n            ray.kill(actor, no_restart=True)\n\n    async def test_put_get_flow_object_ref(self):\n        \"\"\"\n        Tests the data flow for ObjectRef storage (source_dp == dest_dp).\n        - A single worker (rank 0) puts a TensorDict, which becomes an ObjectRef.\n        - All workers get and resolve the same ObjectRef.\n        \"\"\"\n        num_buffers = 1  # Only one buffer is used in this case\n        num_workers = 4\n        source_dp, dest_dp = 4, 4\n        key = \"object_ref_key\"\n\n        # 1. Setup Actors\n        buffers = [DataBuffer.remote(buffer_id=0)]\n        workers = [MockDAGWorker.remote(rank=i, world_size=num_workers, data_buffers=buffers) for i in range(num_workers)]\n        await asyncio.sleep(1)\n\n        # 2. Prepare Data\n        full_dp = _create_test_dp(batch_size=8, seq_len=12)\n\n        # 3. Leader (worker 0) puts data\n        await workers[0].put_data_to_buffers.remote(key, full_dp, source_dp_size=source_dp, dest_dp_size=dest_dp)\n\n        # 4. All workers synchronize\n        await asyncio.gather(*[w.barrier.remote() for w in workers])\n\n        # 5. All workers get data and verify\n        get_futures = [w.get_data_from_buffers.remote(key, my_current_dp_rank=await w.get_rank.remote(), my_current_dp_world_size=dest_dp) for w in workers]\n        results = await asyncio.gather(*get_futures)\n\n        # In the ObjectRef case, every worker gets the same data.\n        for i, result_dp in enumerate(results):\n            self.assertIsNotNone(result_dp, f\"Worker {i} received None\")\n            self.assertTrue(compare_dataprotos(full_dp, result_dp, check_meta=False), f\"TensorDict for worker {i} does not match the original\")\n\n        # Cleanup\n        for actor in workers + buffers:\n            ray.kill(actor)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/dag_worker/test_dapo_merge.py",
    "content": "\"\"\"\nUnit tests for DAPO filtering and merging logic.\nThis test simulates the real scenario from training logs to quickly verify fixes.\n\"\"\"\nimport torch\nimport numpy as np\nfrom tensordict import TensorDict\nfrom tensordict.tensorclass import NonTensorData\nfrom siirl.data_coordinator.sample import filter_tensordict\n\n\ndef get_unpacked_data(value):\n    \"\"\"Safely get data from a value that might be NonTensorData or a raw type.\"\"\"\n    if isinstance(value, NonTensorData):\n        return value.data\n    return value\n\n\ndef create_mock_batch(batch_size=1024):\n    \"\"\"\n    Create a mock batch that simulates the real data structure in DAPO.\n    \n    Includes:\n    - Batched tensor fields (input_ids, attention_mask, etc.)\n    - Batched NonTensorData fields (data_source, ability, etc.)\n    - Metadata NonTensorData fields (eos_token_id, pad_token_id, etc.)\n    \"\"\"\n    batch_dict = {\n        # Tensor fields (batched)\n        'input_ids': torch.randint(0, 1000, (batch_size, 10240)),\n        'attention_mask': torch.ones(batch_size, 10240, dtype=torch.int64),\n        'position_ids': torch.arange(10240).unsqueeze(0).repeat(batch_size, 1),\n        'uid': torch.arange(batch_size, dtype=torch.int64),\n        'prompts': torch.randint(0, 1000, (batch_size, 2048)),\n        'responses': torch.randint(0, 1000, (batch_size, 8192)),\n        'response_mask': torch.ones(batch_size, 8192, dtype=torch.int64),\n        'token_level_scores': torch.randn(batch_size, 8192),\n        'score': torch.randn(batch_size),\n        'acc': torch.randint(0, 2, (batch_size,), dtype=torch.bool),\n        'token_level_rewards': torch.randn(batch_size, 8192),\n        \n        # Batched NonTensorData fields (length = batch_size)\n        'data_source': NonTensorData(\n            data=np.array(['gsm8k'] * batch_size),\n            batch_size=[batch_size]\n        ),\n        'ability': NonTensorData(\n            data=np.array(['math'] * batch_size),\n            batch_size=[batch_size]\n        ),\n        'reward_model': NonTensorData(\n            data=np.array([{'name': 'rm1'}] * batch_size),\n            batch_size=[batch_size]\n        ),\n        'extra_info': NonTensorData(\n            data=np.array([{}] * batch_size),\n            batch_size=[batch_size]\n        ),\n        'pred': NonTensorData(\n            data=np.array(['answer'] * batch_size),\n            batch_size=[batch_size]\n        ),\n        'global_token_num': NonTensorData(\n            data=[10240] * batch_size,\n            batch_size=[batch_size]\n        ),\n        \n        # Metadata NonTensorData fields (length != batch_size)\n        'eos_token_id': NonTensorData(\n            data=[151643, 151645],  # Length 2, not batch_size\n            batch_size=[batch_size]\n        ),\n        'pad_token_id': NonTensorData(\n            data=151643,  # Scalar\n            batch_size=[batch_size]\n        ),\n        'total_input_tokens': NonTensorData(\n            data=168784,  # Scalar\n            batch_size=[batch_size]\n        ),\n        'total_output_tokens': NonTensorData(\n            data=966975,  # Scalar\n            batch_size=[batch_size]\n        ),\n    }\n    \n    return TensorDict(batch_dict, batch_size=batch_size)\n\n\ndef test_filter_tensordict():\n    \"\"\"Test that filter_tensordict correctly handles batched data and metadata.\"\"\"\n    print(\"\\n=== Testing filter_tensordict ===\")\n    \n    # Create mock batch\n    batch = create_mock_batch(batch_size=1024)\n    
print(f\"Original batch size: {len(batch)}\")\n    \n    # Simulate filtering (keeping ~50% of samples, like in the logs)\n    indices = list(range(16, 1024, 2))  # Keep every other sample starting from 16\n    print(f\"Filtering to {len(indices)} samples\")\n    \n    # Filter\n    filtered_batch = filter_tensordict(batch, indices)\n    print(f\"Filtered batch size: {len(filtered_batch)}\")\n    \n    # Verify tensor fields are correctly filtered\n    assert len(filtered_batch['input_ids']) == len(indices)\n    assert len(filtered_batch['uid']) == len(indices)\n    \n    # Verify batched NonTensorData fields are correctly filtered\n    assert len(get_unpacked_data(filtered_batch['data_source'])) == len(indices)\n    assert len(get_unpacked_data(filtered_batch['ability'])) == len(indices)\n    assert len(get_unpacked_data(filtered_batch['global_token_num'])) == len(indices)\n    \n    # Verify metadata NonTensorData fields are preserved (not filtered)\n    assert len(get_unpacked_data(filtered_batch['eos_token_id'])) == 2, \"Metadata should be preserved\"\n    assert get_unpacked_data(filtered_batch['pad_token_id']) == 151643, \"Metadata should be preserved\"\n    \n    print(\"✅ filter_tensordict test passed!\")\n    return filtered_batch\n\n\ndef test_merge_cached_and_new():\n    \"\"\"\n    Test merging cached filtered batch with new filtered batch.\n    This simulates the real merge scenario in dagworker.py.\n    \"\"\"\n    print(\"\\n=== Testing merge logic ===\")\n    \n    # Create first batch and filter (this will be cached)\n    batch1 = create_mock_batch(batch_size=1024)\n    indices1 = list(range(16, 512))  # 496 samples\n    cached_batch = filter_tensordict(batch1, indices1)\n    print(f\"Cached batch size: {len(cached_batch)}\")\n    \n    # Create second batch and filter (this is the new batch)\n    batch2 = create_mock_batch(batch_size=1024)\n    indices2 = list(range(20, 540))  # 520 samples\n    new_batch = filter_tensordict(batch2, indices2)\n    print(f\"New batch size: {len(new_batch)}\")\n    \n    # Now merge them (simulate dagworker merge logic)\n    cache_size = len(cached_batch)\n    new_size = len(new_batch)\n    merged_size = cache_size + new_size\n    print(f\"Expected merged size: {merged_size}\")\n    \n    merged_dict = {}\n    for key in cached_batch.keys():\n        cache_val = cached_batch[key]\n        new_val = new_batch[key]\n        \n        # Debug: print info for NonTensorData fields\n        if isinstance(cache_val, NonTensorData) and key in ['data_source', 'ability', 'eos_token_id']:\n            print(f\"\\nDEBUG {key}:\")\n            print(f\"  cache_val type: {type(cache_val)}, data type: {type(cache_val.data)}\")\n            if hasattr(cache_val.data, '__len__'):\n                print(f\"  cache_val.data length: {len(cache_val.data)}, cache_size: {cache_size}\")\n            print(f\"  new_val type: {type(new_val)}, data type: {type(new_val.data) if isinstance(new_val, NonTensorData) else 'N/A'}\")\n            if isinstance(new_val, NonTensorData) and hasattr(new_val.data, '__len__'):\n                print(f\"  new_val.data length: {len(new_val.data)}, new_size: {new_size}\")\n        \n        if isinstance(cache_val, torch.Tensor):\n            # Merge tensors\n            merged_dict[key] = torch.cat([cache_val, new_val], dim=0)\n        else:\n            # Handle NonTensorData or raw types (TensorDict may unpack NonTensorData)\n            # Extract the actual data\n            if isinstance(cache_val, NonTensorData):\n              
  cache_data = cache_val.data\n            else:\n                cache_data = cache_val\n            \n            if isinstance(new_val, NonTensorData):\n                new_data = new_val.data\n            else:\n                new_data = new_val\n            \n            # Determine if it's batched data or metadata based on length\n            if isinstance(cache_data, np.ndarray):\n                cache_len = len(cache_data)\n                if cache_len == cache_size:\n                    # Batched np.ndarray - merge\n                    new_arr = new_data if isinstance(new_data, np.ndarray) else np.array(new_data)\n                    merged_data = np.concatenate([cache_data, new_arr], axis=0)\n                    merged_dict[key] = NonTensorData(data=merged_data, batch_size=[merged_size])\n                else:\n                    # Metadata - keep as NonTensorData\n                    merged_dict[key] = NonTensorData(data=new_data, batch_size=[merged_size])\n            elif isinstance(cache_data, (list, tuple)):\n                cache_len = len(cache_data)\n                if cache_len == cache_size:\n                    # Batched list - merge\n                    new_list = new_data if isinstance(new_data, (list, tuple)) else [new_data] * new_size\n                    merged_data = list(cache_data) + list(new_list)\n                    merged_dict[key] = NonTensorData(data=merged_data, batch_size=[merged_size])\n                else:\n                    # Metadata - keep as NonTensorData\n                    merged_dict[key] = NonTensorData(data=new_data, batch_size=[merged_size])\n            else:\n                # Scalar - always metadata\n                merged_dict[key] = NonTensorData(data=new_data, batch_size=[merged_size])\n    \n    # Verify all fields before creating TensorDict\n    print(\"\\n--- Pre-creation validation ---\")\n    for key, value in merged_dict.items():\n        if isinstance(value, torch.Tensor):\n            val_len = value.shape[0]\n            status = \"✅\" if val_len == merged_size else \"❌\"\n            print(f\"{status} {key}: Tensor, shape={value.shape}\")\n            assert val_len == merged_size, f\"Tensor {key} has wrong size: {val_len} != {merged_size}\"\n        elif isinstance(value, NonTensorData):\n            if hasattr(value.data, '__len__') and not isinstance(value.data, str):\n                val_len = len(value.data)\n                # For batched data, length should match merged_size\n                # For metadata, length can be different\n                if val_len == merged_size:\n                    print(f\"✅ {key}: NonTensorData(batched), len={val_len}\")\n                else:\n                    print(f\"✅ {key}: NonTensorData(metadata), len={val_len}\")\n            else:\n                print(f\"✅ {key}: NonTensorData(scalar)\")\n        else:\n            print(f\"❌ {key}: Raw type {type(value)} - should be wrapped!\")\n            raise ValueError(f\"Field {key} is not properly wrapped\")\n    \n    # Try to create TensorDict\n    print(\"\\n--- Creating TensorDict ---\")\n    try:\n        merged_batch = TensorDict(merged_dict, batch_size=merged_size)\n        print(f\"✅ Successfully created merged TensorDict with batch_size={merged_size}\")\n        \n        # Verify the result\n        assert len(merged_batch) == merged_size\n        assert len(merged_batch['input_ids']) == merged_size\n        assert len(get_unpacked_data(merged_batch['data_source'])) == merged_size\n        \n        # Verify metadata 
is preserved\n        assert len(get_unpacked_data(merged_batch['eos_token_id'])) == 2\n        assert get_unpacked_data(merged_batch['pad_token_id']) == 151643\n        \n        print(\"✅ Merge test passed!\")\n        return merged_batch\n        \n    except Exception as e:\n        print(f\"❌ Failed to create TensorDict: {e}\")\n        \n        # Print detailed error info\n        print(\"\\n--- Error Details ---\")\n        for k, v in merged_dict.items():\n            v_type = type(v).__name__\n            v_shape = getattr(v, 'shape', 'None')\n            v_dtype = getattr(v, 'dtype', 'None')\n            if isinstance(v, NonTensorData):\n                v_shape = f\"({len(v.data)},)\" if hasattr(v.data, '__len__') else 'scalar'\n                v_dtype = type(v.data).__name__\n            print(f\"  '{k}' -> type={v_type}, shape={v_shape}, dtype={v_dtype}\")\n        \n        raise\n\n\ndef test_full_dapo_workflow():\n    \"\"\"\n    Test the complete DAPO workflow:\n    1. First rollout + filter -> insufficient samples -> cache\n    2. Second rollout + filter -> merge with cache -> sufficient samples\n    \"\"\"\n    print(\"\\n=== Testing full DAPO workflow ===\")\n    \n    # Step 1: First rollout\n    print(\"\\n--- Step 1: First rollout ---\")\n    batch1 = create_mock_batch(batch_size=1024)\n    indices1 = list(range(16, 512))  # 496 samples (insufficient)\n    filtered1 = filter_tensordict(batch1, indices1)\n    print(f\"First rollout: {len(filtered1)} samples (insufficient)\")\n    \n    # Cache it\n    cached_batch = filtered1\n    \n    # Step 2: Second rollout\n    print(\"\\n--- Step 2: Second rollout ---\")\n    batch2 = create_mock_batch(batch_size=1024)\n    indices2 = list(range(20, 540))  # 520 samples\n    filtered2 = filter_tensordict(batch2, indices2)\n    print(f\"Second rollout: {len(filtered2)} samples\")\n    \n    # Step 3: Merge\n    print(\"\\n--- Step 3: Merge cached and new ---\")\n    merged = test_merge_cached_and_new()\n    \n    print(f\"\\n✅ Full workflow test passed! Final batch size: {len(merged)}\")\n\n\nif __name__ == '__main__':\n    print(\"=\" * 60)\n    print(\"DAPO Merge Logic Unit Tests\")\n    print(\"=\" * 60)\n    \n    try:\n        # Test 1: Filter\n        filtered = test_filter_tensordict()\n        \n        # Test 2: Merge\n        merged = test_merge_cached_and_new()\n        \n        # Test 3: Full workflow\n        test_full_dapo_workflow()\n        \n        print(\"\\n\" + \"=\" * 60)\n        print(\"✅ ALL TESTS PASSED!\")\n        print(\"=\" * 60)\n        \n    except Exception as e:\n        print(\"\\n\" + \"=\" * 60)\n        print(f\"❌ TEST FAILED: {e}\")\n        print(\"=\" * 60)\n        import traceback\n        traceback.print_exc()\n        exit(1)\n\n"
  },
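  {
    "path": "tests/dag_worker/example_nontensordata_fields.py",
    "content": "\"\"\"\nIllustrative sketch (not part of the original test suite): distills the rule used by the\nDAPO merge logic in test_dapo_merge.py to tell batched NonTensorData fields apart from\nmetadata NonTensorData fields. A field is treated as batched only when its payload length\nequals the batch size; scalars and payloads of any other length are treated as metadata\nand carried over unchanged. The helper name and this file path are hypothetical.\n\"\"\"\nimport numpy as np\nfrom tensordict.tensorclass import NonTensorData\n\n\ndef classify_non_tensor_field(value, batch_size: int) -> str:\n    \"\"\"Return 'batched' or 'metadata' using the same length check as the merge logic.\"\"\"\n    data = value.data if isinstance(value, NonTensorData) else value\n    if isinstance(data, (np.ndarray, list, tuple)) and len(data) == batch_size:\n        return 'batched'\n    return 'metadata'\n\n\nif __name__ == '__main__':\n    bs = 4\n    # Same construction pattern as create_mock_batch in test_dapo_merge.py\n    fields = {\n        'data_source': NonTensorData(data=np.array(['gsm8k'] * bs), batch_size=[bs]),\n        'eos_token_id': NonTensorData(data=[151643, 151645], batch_size=[bs]),\n        'pad_token_id': NonTensorData(data=151643, batch_size=[bs]),\n    }\n    for name, value in fields.items():\n        print(f'{name}: {classify_non_tensor_field(value, bs)}')\n    # Expected: data_source -> batched, eos_token_id -> metadata, pad_token_id -> metadata\n"
  },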
  {
    "path": "tests/dag_worker/test_dapo_pipeline.py",
    "content": "import pytest\nimport torch\nimport numpy as np\nfrom tensordict import TensorDict\nfrom tensordict.tensorclass import NonTensorData\nimport logging\n\n# Configure logging for tests\nlogging.basicConfig(level=logging.INFO, format='%(levelname)s:%(name)s:%(message)s')\nlogger = logging.getLogger(__name__)\n\nfrom siirl.data_coordinator.sample import preprocess_dataloader\n\ndef test_preprocess_dataloader():\n    \"\"\"\n    Tests if preprocess_dataloader correctly handles data repetition (n > 1)\n    and creates integer uids.\n    \"\"\"\n    data = {\n        'input_ids': np.array([[1, 2], [3, 4]]),\n        'attention_mask': torch.tensor([[1, 1], [1, 0]]),\n        'data_source': ['d1', 'd2'] \n    }\n    n = 2\n    \n    tensor_dict = preprocess_dataloader(data, n=n)\n    \n    # 1. Check batch size\n    assert tensor_dict.batch_size[0] == 4\n    \n    # 2. Check uid creation and type\n    assert 'uid' in tensor_dict.keys()\n    assert len(tensor_dict['uid']) == 4\n    # uids should be [0, 0, 1, 1] after repeat\n    expected_uids = np.array([0, 0, 1, 1], dtype=np.int64)\n    assert np.array_equal(tensor_dict['uid'], expected_uids)\n\n    # 3. Check numpy array repetition\n    expected_input_ids = np.array([[1, 2], [1, 2], [3, 4], [3, 4]])\n    assert np.array_equal(tensor_dict['input_ids'], expected_input_ids)\n    \n    # 4. Check torch tensor repetition\n    expected_attention_mask = torch.tensor([[1, 1], [1, 1], [1, 0], [1, 0]])\n    assert torch.equal(tensor_dict['attention_mask'], expected_attention_mask)\n    \n    # 5. Check NonTensorData repetition\n    assert isinstance(tensor_dict['data_source'], np.ndarray)\n    expected_data_source = np.array(['d1', 'd1', 'd2', 'd2'])\n    assert np.array_equal(tensor_dict['data_source'], expected_data_source)\n\n\n# Mock DAGWorker for testing postprocess_sampling\nclass MockDAGWorker:\n    def __init__(self, rank):\n        self._rank = rank\n        self.sampling_leftover_cache = None\n\n    # Simplified version of the method for testing purposes\n    def postprocess_sampling(self, config, batch: TensorDict, filtered_indices: list):\n        \n        # Mock dynamic_sampling behavior\n        filtered_batch = batch[filtered_indices]\n        metrics = {\n            'dapo_sampling/kept_trajectories_ratio': len(filtered_indices) / len(batch),\n            'dapo_sampling/filtered_indices': filtered_indices # Simulate returning indices in metrics\n        }\n\n        # --- This part now mirrors the production code ---\n        \n        # Get the indices from the metrics dict\n        local_filtered_indices = metrics.pop('dapo_sampling/filtered_indices', [])\n\n        if self.sampling_leftover_cache is not None:\n            # Manual merge logic from the actual implementation\n            from tensordict.tensorclass import NonTensorData\n            cache_size = len(self.sampling_leftover_cache)\n            new_size = len(filtered_batch)\n            merged_size = cache_size + new_size\n            \n            merged_dict = {}\n            # Use keys from the new batch as it's guaranteed to be non-empty\n            for key in set(self.sampling_leftover_cache.keys()) | set(filtered_batch.keys()):\n                cache_val = self.sampling_leftover_cache.get(key)\n                new_val = filtered_batch.get(key)\n\n                print(f\"\\n--- Processing key: {key} ---\")\n                \n                if cache_val is None:\n                    merged_dict[key] = new_val\n                    continue\n               
 if new_val is None:\n                    merged_dict[key] = cache_val\n                    continue\n\n                print(f\"  Cache type: {type(cache_val)}, New type: {type(new_val)}\")\n\n                # --- Final, explicit merge logic ---\n                if isinstance(cache_val, torch.Tensor):\n                    print(\"  Type is torch.Tensor. Concatenating.\")\n                    merged_dict[key] = torch.cat([cache_val, new_val], dim=0)\n                elif isinstance(cache_val, NonTensorData):\n                    # Check the type of the wrapped data\n                    if isinstance(cache_val.data, np.ndarray):\n                        # It's a numpy array wrapped in NonTensorData.\n                        new_val_filtered = np.array(batch[key].data)[local_filtered_indices]\n                        merged_data = np.concatenate([cache_val.data, new_val_filtered], axis=0)\n                        merged_dict[key] = NonTensorData(data=merged_data, batch_size=torch.Size([merged_size]))\n                    elif isinstance(cache_val.data, (list, tuple)):\n                        # It's batched list-like data.\n                        # new_val is already the filtered list\n                        merged_data = cache_val.tolist() + new_val\n                        merged_dict[key] = NonTensorData(data=merged_data, batch_size=torch.Size([merged_size]))\n                    else:\n                        # It's metadata, keep the new value\n                        merged_dict[key] = new_val\n                else:\n                    # Fallback for any other metadata (simple int, str, etc.)\n                    print(f\"  Type is other metadata ({type(cache_val)}). Keeping new value.\")\n                    merged_dict[key] = new_val\n            \n            print(\"--- Merge Complete ---\")\n            filtered_batch = TensorDict(merged_dict, batch_size=torch.Size([merged_size]))\n            self.sampling_leftover_cache = None\n        \n        # Mock distributed aggregation and decision\n        total_samples = len(filtered_batch) # Simplified for single-process test\n        target_total_samples = config['target_total_samples']\n\n        if total_samples < target_total_samples:\n            self.sampling_leftover_cache = filtered_batch\n            return TensorDict({}, batch_size=(0,)), metrics\n        else:\n            return filtered_batch, metrics\n\n@pytest.fixture\ndef sample_batch():\n    \"\"\"Provides a sample TensorDict for testing.\"\"\"\n    return TensorDict({\n        'uid': torch.arange(8),\n        'input_ids': torch.randn(8, 10),\n        'data_source': NonTensorData(np.array([f\"src_{i}\" for i in range(8)]), batch_size=8)\n    }, batch_size=8)\n\n\ndef test_postprocess_sampling_caching(sample_batch):\n    \"\"\"\n    Tests if postprocess_sampling correctly caches data when there are insufficient samples.\n    \"\"\"\n    worker = MockDAGWorker(rank=0)\n    config = {'target_total_samples': 10}\n    \n    # First pass: not enough samples, should cache and return empty\n    filtered_indices = [1, 3, 5]\n    result_batch, _ = worker.postprocess_sampling(config, sample_batch, filtered_indices)\n    \n    assert len(result_batch) == 0\n    assert worker.sampling_leftover_cache is not None\n    assert len(worker.sampling_leftover_cache) == 3\n    assert torch.equal(worker.sampling_leftover_cache['uid'], torch.tensor([1, 3, 5]))\n\ndef test_postprocess_sampling_merging(sample_batch):\n    \"\"\"\n    Tests if postprocess_sampling correctly merges cached 
data with new data.\n    \"\"\"\n    worker = MockDAGWorker(rank=0)\n    config = {'target_total_samples': 5}\n    \n    # First, cache some data\n    worker.sampling_leftover_cache = TensorDict({\n        'uid': torch.tensor([10, 20]),\n        'input_ids': torch.randn(2, 10),\n        'data_source': NonTensorData(np.array(['cached_1', 'cached_2']), batch_size=2),\n        'metadata_field': NonTensorData(data=1, batch_size=torch.Size([2])) # Wrap metadata in NonTensorData\n    }, batch_size=2)\n    \n    # Add metadata to the new batch as well\n    sample_batch['metadata_field'] = NonTensorData(data=2, batch_size=sample_batch.batch_size) # New metadata\n    \n    # Second pass: new data should be merged with cache\n    filtered_indices = [0, 2, 4, 6]\n    result_batch, _ = worker.postprocess_sampling(config, sample_batch, filtered_indices)\n    \n    # Total samples (2 cached + 4 new = 6) >= target (5), so should return merged batch\n    assert worker.sampling_leftover_cache is None\n    assert len(result_batch) == 6\n    \n    # Check merged content\n    expected_uids = torch.tensor([10, 20, 0, 2, 4, 6])\n    assert torch.equal(result_batch['uid'], expected_uids)\n    \n    expected_sources = ['cached_1', 'cached_2', 'src_0', 'src_2', 'src_4', 'src_6']\n    assert np.array_equal(result_batch['data_source'].data, np.array(expected_sources))\n    \n    # Check that the metadata from the NEW batch is kept\n    assert result_batch['metadata_field'] == 2\n"
  },
  {
    "path": "tests/data_buffer/detailed_put_performance_test.py",
    "content": "import asyncio\nimport time\nimport ray\nimport torch\nimport numpy as np\nfrom tensordict import TensorDict\nfrom typing import List, Tuple\nimport datetime\nimport uuid\nimport statistics\n\n# Make sure the import path is correct based on your project structure\nfrom siirl.data_coordinator.data_buffer import init_data_coordinator\nfrom siirl.data_coordinator.sample import SampleInfo\n\n\n# ====================================================================\n# Performance Test Configuration\n# ====================================================================\n# --- Data Generation Parameters ---\nTOTAL_SAMPLES = 256  # Total number of samples to generate for the detailed test\nBATCH_SIZE_PER_SAMPLE = 1\nSEQ_LEN = 1024\nEMBED_DIM = 1024\n\n# --- Workload Parameters ---\nNUM_PRODUCERS = 4  # Number of concurrent producers to simulate\nNUM_BUFFERS = NUM_PRODUCERS\n\n# --- Batching Simulation Parameters ---\nBATCH_SIZE_FOR_SIM = 64  # How many samples to group in our simulated batch\n\n\ndef log_with_time(message: str):\n    \"\"\"Prints a message with a timestamp and flushes the output.\"\"\"\n    now = datetime.datetime.now().strftime(\"%Y-%m-%d %H:%M:%S.%f\")[:-3]\n    print(f\"[{now}] {message}\", flush=True)\n\n\ndef create_mock_sample() -> TensorDict:\n    \"\"\"Creates a mock sample (TensorDict) for testing.\"\"\"\n    tensor_data = {\n        \"input_ids\": torch.randint(0, 32000, (BATCH_SIZE_PER_SAMPLE, SEQ_LEN)),\n        \"attention_mask\": torch.ones(BATCH_SIZE_PER_SAMPLE, SEQ_LEN, dtype=torch.long),\n        \"hidden_states\": torch.randn(BATCH_SIZE_PER_SAMPLE, SEQ_LEN, EMBED_DIM),\n    }\n    return TensorDict(tensor_data, batch_size=[BATCH_SIZE_PER_SAMPLE])\n\n\nasync def producer_task_detailed_profile(\n    producer_id: int, \n    coordinator: ray.actor.ActorHandle, \n    num_samples_to_produce: int\n) -> Tuple[List[float], List[float]]:\n    \"\"\"\n    Simulates a producer and records the detailed timings for \n    `ray.put` and `coordinator.put.remote`.\n    \"\"\"\n    ray_put_timings = []\n    coord_put_timings = []\n\n    for _ in range(num_samples_to_produce):\n        sample_data = create_mock_sample()\n        \n        # 1. Profile `ray.put` (Serialization + Local Object Store Write)\n        start_put = time.perf_counter()\n        sample_ref = ray.put(sample_data)\n        end_put = time.perf_counter()\n        ray_put_timings.append(end_put - start_put)\n        \n        sample_info = SampleInfo(uid=str(uuid.uuid4()))\n        \n        # 2. 
Profile `coordinator.put.remote` (RPC Overhead + Remote Execution)\n        start_coord = time.perf_counter()\n        await coordinator.put.remote(sample_info, sample_ref)\n        end_coord = time.perf_counter()\n        coord_put_timings.append(end_coord - start_coord)\n        \n    return ray_put_timings, coord_put_timings\n\n\ndef analyze_timings(operation_name: str, timings: List[float]):\n    \"\"\"Analyzes and prints statistics for a list of timings.\"\"\"\n    if not timings:\n        log_with_time(f\"  - No data for {operation_name}.\")\n        return\n\n    total_time = sum(timings)\n    mean_time = statistics.mean(timings) * 1000  # ms\n    std_dev = statistics.stdev(timings) * 1000 if len(timings) > 1 else 0.0  # ms\n    \n    log_with_time(f\"  - Analysis for '{operation_name}':\")\n    log_with_time(f\"    - Total Time: {total_time:.4f} seconds for {len(timings)} calls\")\n    log_with_time(f\"    - Average Latency: {mean_time:.4f} ms/call\")\n    log_with_time(f\"    - Standard Deviation: {std_dev:.4f} ms\")\n\n\nasync def main():\n    \"\"\"Main performance test function.\"\"\"\n    if not ray.is_initialized():\n        ray.init(ignore_reinit_error=True, logging_level=\"error\")\n\n    log_with_time(\"=\" * 80)\n    log_with_time(\"  Detailed Performance Profile: ray.put vs coordinator.put.remote\")\n    log_with_time(\"=\" * 80)\n\n    coordinator = init_data_coordinator(NUM_BUFFERS, force_local=True)\n    \n    # --- Part 1: Detailed Profile of the Current \"Sample-by-Sample\" Method ---\n    log_with_time(\"\\n--- Part 1: Profiling Current Sample-by-Sample Approach ---\")\n    samples_per_producer = TOTAL_SAMPLES // NUM_PRODUCERS\n    \n    producer_tasks = []\n    for i in range(NUM_PRODUCERS):\n        task = producer_task_detailed_profile(i, coordinator, samples_per_producer)\n        producer_tasks.append(task)\n            \n    results = await asyncio.gather(*producer_tasks)\n    \n    all_ray_put_timings = [t for res in results for t in res[0]]\n    all_coord_put_timings = [t for res in results for t in res[1]]\n\n    analyze_timings(\"ray.put (Serialization)\", all_ray_put_timings)\n    print() # Spacer\n    analyze_timings(\"coordinator.put.remote (RPC)\", all_coord_put_timings)\n\n    total_ray_put_time = sum(all_ray_put_timings)\n    total_coord_put_time = sum(all_coord_put_timings)\n    \n    log_with_time(\"\\n  - Conclusion for Part 1:\")\n    if total_ray_put_time > total_coord_put_time:\n        ratio = total_ray_put_time / total_coord_put_time if total_coord_put_time > 0 else float('inf')\n        log_with_time(f\"    - Serialization (`ray.put`) is the dominant cost, taking {total_ray_put_time:.4f}s.\")\n        log_with_time(f\"    - `ray.put` is {ratio:.2f}x slower than the coordinator RPC.\")\n    else:\n        ratio = total_coord_put_time / total_ray_put_time if total_ray_put_time > 0 else float('inf')\n        log_with_time(f\"    - RPC (`coordinator.put.remote`) is the dominant cost, taking {total_coord_put_time:.4f}s.\")\n        log_with_time(f\"    - The RPC is {ratio:.2f}x slower than serialization.\")\n\n    # --- Part 2: Local Benchmark of a Potential \"Batched\" Optimization ---\n    log_with_time(\"\\n--- Part 2: Simulating Batched `ray.put` Optimization Potential ---\")\n    \n    # Create a batch of samples locally\n    mock_batch = [create_mock_sample() for _ in range(BATCH_SIZE_FOR_SIM)]\n    \n    # Time a single, batched ray.put\n    start_batch_put = time.perf_counter()\n    ray.put(mock_batch)\n    end_batch_put = 
time.perf_counter()\n    batched_put_time = end_batch_put - start_batch_put\n    \n    # Get the average time for individual puts from our earlier test\n    avg_single_put_time = statistics.mean(all_ray_put_timings)\n    equivalent_individual_time = avg_single_put_time * BATCH_SIZE_FOR_SIM\n    \n    log_with_time(f\"  - Time to `ray.put` {BATCH_SIZE_FOR_SIM} samples INDIVIDUALLY: {equivalent_individual_time:.4f} seconds (estimated from Part 1)\")\n    log_with_time(f\"  - Time to `ray.put` {BATCH_SIZE_FOR_SIM} samples as a single BATCH: {batched_put_time:.4f} seconds\")\n    \n    if batched_put_time > 0 :\n        speedup_factor = equivalent_individual_time / batched_put_time\n        log_with_time(f\"  - Potential Speedup Factor: {speedup_factor:.2f}x\")\n    log_with_time(\"    - This demonstrates the potential performance gain from reducing serialization overhead.\")\n\n    # --- Cleanup ---\n    log_with_time(\"\\n\\n🧹 Cleaning up resources...\")\n    ray.kill(coordinator, no_restart=True)\n    ray.shutdown()\n    log_with_time(\"=\" * 80)\n    log_with_time(\"               Benchmark Finished\")\n    log_with_time(\"=\" * 80)\n\n\nif __name__ == \"__main__\":\n    asyncio.run(main())\n"
  },
  {
    "path": "tests/data_buffer/performance_test_data_buffer.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport asyncio\nimport time\nimport ray\nimport torch\nimport numpy as np\nfrom tensordict import TensorDict\nfrom typing import List\nimport datetime\nimport random\nimport uuid\n\n# Make sure the import path is correct based on your project structure\nfrom siirl.data_coordinator.data_buffer import init_data_coordinator\nfrom siirl.data_coordinator.sample import SampleInfo\n\n# ====================================================================\n# Performance Test Configuration for New Architecture\n# ====================================================================\n# --- Data Generation Parameters ---\n# Total number of samples to generate (aligned with old test: 200 items * 8 batch/item = 1600 samples)\nTOTAL_SAMPLES = 1600\n# Batch size within each sample (usually 1, simulating a single trajectory)\nBATCH_SIZE_PER_SAMPLE = 1\n# Sequence length \nSEQ_LEN = 1024\n# Embedding dimension \nEMBED_DIM = 1024\n\n# --- Workload Parameters ---\n# Number of concurrent RolloutWorkers to simulate data production\nNUM_PRODUCERS = 8\n# Batch size for a simulated Trainer to get at once\nTRAINER_BATCH_SIZE = 128\n# Number of distributed DataBuffers (should equal NUM_PRODUCERS to simulate each worker having its own local buffer)\nNUM_BUFFERS = NUM_PRODUCERS\n\n\ndef log_with_time(message: str):\n    \"\"\"Prints a message with a timestamp and flushes the output.\"\"\"\n    now = datetime.datetime.now().strftime(\"%Y-%m-%d %H:%M:%S.%f\")[:-3]\n    print(f\"[{now}] {message}\", flush=True)\n\n\ndef create_mock_sample(item_idx: int) -> TensorDict:\n    \"\"\"Creates a mock sample (TensorDict) for testing.\"\"\"\n    tensor_data = {\n        \"input_ids\": torch.randint(0, 32000, (BATCH_SIZE_PER_SAMPLE, SEQ_LEN)),\n        \"attention_mask\": torch.ones(BATCH_SIZE_PER_SAMPLE, SEQ_LEN, dtype=torch.long),\n        \"hidden_states\": torch.randn(BATCH_SIZE_PER_SAMPLE, SEQ_LEN, EMBED_DIM),\n    }\n    return TensorDict(tensor_data, batch_size=[BATCH_SIZE_PER_SAMPLE])\n\n\nasync def producer_task(\n    producer_id: int, \n    coordinator: ray.actor.ActorHandle, \n    num_samples_to_produce: int\n):\n    \"\"\"Simulates the behavior of a RolloutWorker: produce data, store it locally, and register it globally.\"\"\"\n    for i in range(num_samples_to_produce):\n        sample_data = create_mock_sample(producer_id * num_samples_to_produce + i)\n        \n        # 1. Store the sample data in the Ray object store of the current node\n        sample_ref = ray.put(sample_data)\n        \n        # 2. Create metadata\n        sample_info = SampleInfo(\n            agent_group=producer_id,\n            sum_tokens=SEQ_LEN,\n            prompt_length=SEQ_LEN,\n            response_length=0, # Assume response is empty in the producer phase\n            uid=uuid.uuid4().int # Generate a unique integer ID\n        )\n        \n        # 3. 
Register the metadata and reference with the global DataCoordinator\n        #    The DataCoordinator will automatically handle holding the reference locally\n        await coordinator.put.remote(sample_info, sample_ref)\n\n\nasync def main():\n    \"\"\"Main performance test function.\"\"\"\n    if not ray.is_initialized():\n        ray.init(ignore_reinit_error=True, logging_level=\"error\")\n\n    log_with_time(\"=\" * 80)\n    log_with_time(\"      New DataCoordinator Architecture - Performance Benchmark\")\n    log_with_time(\"=\" * 80)\n    log_with_time(f\"Configuration:\")\n    log_with_time(f\"  - Total Samples to Generate: {TOTAL_SAMPLES}\")\n    log_with_time(f\"  - Concurrent Producers (RolloutWorkers): {NUM_PRODUCERS}\")\n    log_with_time(f\"  - Trainer Batch Size: {TRAINER_BATCH_SIZE}\")\n    log_with_time(f\"  - Distributed Buffers: {NUM_BUFFERS}\")\n    log_with_time(\"-\" * 80)\n\n    # 1. Initialize the data coordination system\n    # For single-machine testing, force_local=True must be set to avoid waiting for multiple nodes\n    coordinator = init_data_coordinator(NUM_BUFFERS, force_local=True)\n    log_with_time(f\"✅ Data system initialized with 1 Coordinator.\")\n\n    # 2. Test concurrent Producer (RolloutWorker) performance\n    log_with_time(\"\\n--- Testing Concurrent Producer (put) Performance ---\")\n    \n    samples_per_producer = TOTAL_SAMPLES // NUM_PRODUCERS\n    if TOTAL_SAMPLES % NUM_PRODUCERS != 0:\n        log_with_time(f\"Warning: Total samples not evenly divisible by producers. Some producers will generate more samples.\")\n    \n    start_time = time.perf_counter()\n    \n    producer_tasks = []\n    for i in range(NUM_PRODUCERS):\n        # The RolloutWorker now only needs to interact with the Coordinator\n        num_to_produce = samples_per_producer + (1 if i < TOTAL_SAMPLES % NUM_PRODUCERS else 0)\n        if num_to_produce > 0:\n            task = producer_task(i, coordinator, num_to_produce)\n            producer_tasks.append(task)\n            \n    await asyncio.gather(*producer_tasks)\n    end_time = time.perf_counter()\n\n    total_put_time = end_time - start_time\n    put_throughput = TOTAL_SAMPLES / total_put_time\n    avg_put_latency = (total_put_time / TOTAL_SAMPLES) * 1000  # ms per sample\n\n    log_with_time(f\"  - Total time to produce and register {TOTAL_SAMPLES} samples: {total_put_time:.4f} seconds\")\n    log_with_time(f\"  - Producer Throughput: {put_throughput:.2f} samples/sec\")\n    log_with_time(f\"  - Average Producer Latency: {avg_put_latency:.4f} ms/sample\")\n    \n    # Verify that all data has been stored in the Coordinator\n    queue_size = await coordinator.get_valid_size.remote()\n    assert queue_size == TOTAL_SAMPLES, f\"Coordinator size mismatch! Expected {TOTAL_SAMPLES}, got {queue_size}\"\n    log_with_time(\"✅ All samples registered in Coordinator.\")\n\n    # 3. Test Consumer (Trainer) performance\n    log_with_time(\"\\n--- Testing Consumer (get_batch) Performance ---\")\n    num_batches_to_get = TOTAL_SAMPLES // TRAINER_BATCH_SIZE\n    log_with_time(f\"Simulating a trainer fetching {num_batches_to_get} batches of size {TRAINER_BATCH_SIZE}.\")\n    \n    consumer_start_time = time.perf_counter()\n    \n    total_retrieved_samples = 0\n    for _ in range(num_batches_to_get):\n        # 1. 
Get a batch of sample references from the Coordinator (or the values directly due to Ray optimization)\n        batch_refs_or_values = await coordinator.get_batch.remote(TRAINER_BATCH_SIZE)\n        if not batch_refs_or_values:\n            log_with_time(\"  - Coordinator returned empty batch, stopping consumer test.\")\n            break\n            \n        # 2. Differentiate between returned ObjectRefs and resolved values, and handle them accordingly\n        resolved_batch = []\n        refs_to_get = []\n        for item in batch_refs_or_values:\n            if isinstance(item, ray.ObjectRef):\n                refs_to_get.append(item)\n            else:\n                # Ray might return the value directly because the caller is the owner\n                resolved_batch.append(item)\n        \n        # Batch get all ObjectRefs that need to be resolved\n        if refs_to_get:\n            loop = asyncio.get_running_loop()\n            resolved_from_refs = await loop.run_in_executor(None, ray.get, refs_to_get)\n            resolved_batch.extend(resolved_from_refs)\n\n        actual_data_batch = resolved_batch\n        total_retrieved_samples += len(actual_data_batch)\n\n    consumer_end_time = time.perf_counter()\n\n    total_get_time = consumer_end_time - consumer_start_time\n    get_throughput = total_retrieved_samples / total_get_time\n    avg_batch_latency = (total_get_time / num_batches_to_get) * 1000 # ms per batch\n\n    log_with_time(f\"  - Total time for consumer to fetch {total_retrieved_samples} samples: {total_get_time:.4f} seconds\")\n    log_with_time(f\"  - Consumer Throughput: {get_throughput:.2f} samples/sec\")\n    log_with_time(f\"  - Average Batch Latency: {avg_batch_latency:.4f} ms/batch\")\n    \n    assert total_retrieved_samples == num_batches_to_get * TRAINER_BATCH_SIZE, \"Mismatch in retrieved sample count!\"\n    log_with_time(\"✅ Data integrity check passed for consumer results.\")\n\n    # 4. Clean up resources\n    log_with_time(\"\\n\\n🧹 Cleaning up resources...\")\n    # We cannot access the data_buffers directly, but we can manage them indirectly\n    # through the coordinator or directly via ray.kill.\n    # For simplicity, we only kill the coordinator as it's the only handle we have.\n    ray.kill(coordinator, no_restart=True)\n    # Note: Placement group and detached actors might need manual cleanup if not managed properly.\n    # For this test, shutting down ray is sufficient.\n    ray.shutdown()\n    log_with_time(\"=\" * 80)\n    log_with_time(\"               Benchmark Finished\")\n    log_with_time(\"=\" * 80)\n\n\nif __name__ == \"__main__\":\n    asyncio.run(main())\n"
  },
  {
    "path": "tests/data_buffer/test_data_buffer.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport unittest\nimport asyncio\nimport ray\nimport torch\nimport uuid\nfrom tensordict import TensorDict\nfrom typing import Dict, Any, List\n\n# Imports from the refactored implementation\nfrom siirl.data_coordinator.data_buffer import init_data_coordinator\nfrom siirl.data_coordinator.sample import SampleInfo\n\n\nclass TestDataCoordinator(unittest.IsolatedAsyncioTestCase):\n    \"\"\"\n    Unit tests for the new DataCoordinator/DataBuffer architecture.\n    \"\"\"\n\n    @classmethod\n    def setUpClass(cls):\n        if not ray.is_initialized():\n            ray.init(num_cpus=4, ignore_reinit_error=True, logging_level=\"error\")\n\n    @classmethod\n    def tearDownClass(cls):\n        if ray.is_initialized():\n            ray.shutdown()\n\n    async def asyncSetUp(self):\n        \"\"\"Create a new, clean DataCoordinator system for each test.\"\"\"\n        # Use force_local=True for single-node unit testing\n        self.coordinator = init_data_coordinator(num_buffers=2, force_local=True)\n        # Ensure the actor has started and is ready\n        self.assertEqual(await self.coordinator.get_valid_size.remote(), 0)\n\n    async def asyncTearDown(self):\n        \"\"\"Destroy the actor after each test.\"\"\"\n        # The coordinator is a detached actor, so we must manually kill it.\n        ray.kill(self.coordinator, no_restart=True)\n        # Allow some time for cleanup\n        await asyncio.sleep(0.1)\n\n    def _create_mock_sample(self, content_id: int) -> TensorDict:\n        \"\"\"Helper to create a sample with identifiable content.\"\"\"\n        return TensorDict({\"data\": torch.tensor([[content_id]])}, batch_size=[1])\n\n    def _create_mock_sample_info(self, tokens: int, group: int = 0) -> SampleInfo:\n        \"\"\"Helper to create a SampleInfo object.\"\"\"\n        return SampleInfo(\n            agent_group=group,\n            sum_tokens=tokens,\n            prompt_length=tokens,\n            response_length=0,\n            uid=uuid.uuid4().int\n        )\n\n    # === Test Cases ===\n\n    async def test_put_increases_size(self):\n        \"\"\"Test that calling `put` increases the coordinator's queue size.\"\"\"\n        initial_size = await self.coordinator.get_valid_size.remote()\n        self.assertEqual(initial_size, 0)\n\n        sample_ref = ray.put(self._create_mock_sample(1))\n        sample_info = self._create_mock_sample_info(tokens=128)\n        \n        await self.coordinator.put.remote(sample_info, sample_ref)\n        \n        new_size = await self.coordinator.get_valid_size.remote()\n        self.assertEqual(new_size, 1)\n\n    async def test_get_batch_simple(self):\n        \"\"\"Test basic `get_batch` functionality and data integrity.\"\"\"\n        # 1. 
Put a sample\n        sample_data = self._create_mock_sample(101)\n        sample_ref = ray.put(sample_data)\n        sample_info = self._create_mock_sample_info(tokens=128)\n        await self.coordinator.put.remote(sample_info, sample_ref)\n        self.assertEqual(await self.coordinator.get_valid_size.remote(), 1)\n\n        # 2. Get a batch\n        batch_refs_or_values = await self.coordinator.get_batch.remote(batch_size=1)\n        self.assertEqual(len(batch_refs_or_values), 1)\n        \n        # 3. Verify queue size decreases\n        self.assertEqual(await self.coordinator.get_valid_size.remote(), 0)\n\n        # 4. Verify data integrity\n        retrieved_item = batch_refs_or_values[0]\n        if isinstance(retrieved_item, ray.ObjectRef):\n            retrieved_item = ray.get(retrieved_item)\n        \n        self.assertTrue(torch.equal(retrieved_item.get(\"data\"), sample_data.get(\"data\")))\n\n    async def test_get_batch_insufficient_samples(self):\n        \"\"\"Test that `get_batch` returns an empty list when samples are insufficient.\"\"\"\n        # Queue is empty\n        batch = await self.coordinator.get_batch.remote(batch_size=1)\n        self.assertEqual(len(batch), 0)\n\n        # Queue has 1, but we request 2\n        sample_ref = ray.put(self._create_mock_sample(1))\n        sample_info = self._create_mock_sample_info(tokens=128)\n        await self.coordinator.put.remote(sample_info, sample_ref)\n        \n        batch = await self.coordinator.get_batch.remote(batch_size=2)\n        self.assertEqual(len(batch), 0)\n        # Ensure the queue was not modified\n        self.assertEqual(await self.coordinator.get_valid_size.remote(), 1)\n\n    async def test_get_batch_fifo_order(self):\n        \"\"\"Test that `get_batch` respects FIFO order without a filter.\"\"\"\n        # Put 3 samples with identifiable content IDs\n        samples_put = [self._create_mock_sample(i) for i in [10, 20, 30]]\n        for sample in samples_put:\n            sample_ref = ray.put(sample)\n            sample_info = self._create_mock_sample_info(tokens=128)\n            await self.coordinator.put.remote(sample_info, sample_ref)\n\n        # Get a batch of 2\n        batch_refs_or_values = await self.coordinator.get_batch.remote(batch_size=2)\n        retrieved_data = [ray.get(item) if isinstance(item, ray.ObjectRef) else item for item in batch_refs_or_values]\n        \n        # Verify the first two samples were returned\n        self.assertTrue(torch.equal(retrieved_data[0].get(\"data\"), samples_put[0].get(\"data\")))\n        self.assertTrue(torch.equal(retrieved_data[1].get(\"data\"), samples_put[1].get(\"data\")))\n        self.assertEqual(await self.coordinator.get_valid_size.remote(), 1)\n\n    async def test_get_batch_with_filter(self):\n        \"\"\"Test that `get_batch` correctly applies the filter_plugin.\"\"\"\n        # Put 4 samples with different token counts\n        sample_infos = [\n            self._create_mock_sample_info(tokens=100), # Should be filtered out\n            self._create_mock_sample_info(tokens=600), # Should be selected\n            self._create_mock_sample_info(tokens=200), # Should be filtered out\n            self._create_mock_sample_info(tokens=700), # Should be selected\n        ]\n        samples_put = [self._create_mock_sample(info.sum_tokens) for info in sample_infos]\n\n        for info, data in zip(sample_infos, samples_put):\n            sample_ref = ray.put(data)\n            await self.coordinator.put.remote(info, sample_ref)\n       
 \n        self.assertEqual(await self.coordinator.get_valid_size.remote(), 4)\n        \n        # Define a filter that only accepts samples with >= 512 tokens\n        def long_sample_filter(sample_info: SampleInfo) -> bool:\n            return sample_info.sum_tokens >= 512\n\n        # Request a batch of 2 with the filter\n        batch_refs_or_values = await self.coordinator.get_batch.remote(\n            batch_size=2, filter_plugin=long_sample_filter\n        )\n        self.assertEqual(len(batch_refs_or_values), 2)\n        \n        # Verify the correct samples were returned (600 and 700)\n        retrieved_data = [ray.get(item) if isinstance(item, ray.ObjectRef) else item for item in batch_refs_or_values]\n        retrieved_tokens = sorted([item.get(\"data\").item() for item in retrieved_data])\n        self.assertEqual(retrieved_tokens, [600, 700])\n\n        # Verify that the filtered-out samples remain in the queue\n        self.assertEqual(await self.coordinator.get_valid_size.remote(), 2)\n\n\n    async def test_get_batch_with_node_affinity_filter(self):\n        \"\"\"\n        Test using the filter_plugin to achieve node-affinity scheduling.\n        This simulates a trainer on 'node_A' preferentially pulling data\n        that was also produced on 'node_A'.\n        \"\"\"\n        # 1. Simulate data coming from two different nodes\n        # Samples from node_A\n        for i in range(3):\n            info = self._create_mock_sample_info(tokens=128, group=i)\n            info.node_id = \"node_A\" # Manually set node_id for testing\n            data = self._create_mock_sample(content_id=100 + i)\n            await self.coordinator.put.remote(info, ray.put(data))\n\n        # Samples from node_B\n        for i in range(2):\n            info = self._create_mock_sample_info(tokens=256, group=i)\n            info.node_id = \"node_B\" # Manually set node_id for testing\n            data = self._create_mock_sample(content_id=200 + i)\n            await self.coordinator.put.remote(info, ray.put(data))\n\n        self.assertEqual(await self.coordinator.get_valid_size.remote(), 5)\n\n        # 2. Create a filter factory to generate an affinity filter\n        # This closure captures the desired local_node_id.\n        def create_affinity_filter(local_node_id: str):\n            def affinity_filter(sample_info: SampleInfo) -> bool:\n                return sample_info.node_id == local_node_id\n            return affinity_filter\n\n        # 3. Simulate a Trainer on 'node_A' pulling its local data\n        trainer_on_node_a_filter = create_affinity_filter(\"node_A\")\n        \n        local_batch = await self.coordinator.get_batch.remote(\n            batch_size=3, filter_plugin=trainer_on_node_a_filter\n        )\n        self.assertEqual(len(local_batch), 3)\n\n        # Verify that we got all the data from node_A\n        retrieved_data = ray.get(local_batch)\n        retrieved_ids = sorted([d.get(\"data\").item() for d in retrieved_data])\n        self.assertEqual(retrieved_ids, [100, 101, 102])\n        \n        # 4. The coordinator should now only contain data from node_B\n        self.assertEqual(await self.coordinator.get_valid_size.remote(), 2)\n\n        # 5. 
Now, the trainer on 'node_A' can fetch remote data if needed\n        # (by inverting the filter or using a different one)\n        trainer_on_node_b_filter = create_affinity_filter(\"node_B\")\n        remote_batch = await self.coordinator.get_batch.remote(\n            batch_size=2, filter_plugin=trainer_on_node_b_filter\n        )\n        self.assertEqual(len(remote_batch), 2)\n        retrieved_data = ray.get(remote_batch)\n        retrieved_ids = sorted([d.get(\"data\").item() for d in retrieved_data])\n        self.assertEqual(retrieved_ids, [200, 201])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
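  {
    "path": "tests/data_buffer/example_filter_plugin_composition.py",
    "content": "\"\"\"\nIllustrative sketch (not part of the original test suite): composes two filter_plugin\npredicates of the kind exercised in test_data_buffer.py (node affinity and a minimum\ntoken count) into a single callable that could be passed to DataCoordinator.get_batch.\nThe composition helper and this file path are hypothetical; the SampleInfo fields mirror\nthose used by the neighbouring tests.\n\"\"\"\nimport uuid\n\nfrom siirl.data_coordinator.sample import SampleInfo\n\n\ndef make_affinity_and_length_filter(local_node_id: str, min_tokens: int):\n    \"\"\"Return a predicate that keeps only local samples with at least min_tokens tokens.\"\"\"\n    def combined_filter(sample_info: SampleInfo) -> bool:\n        return sample_info.node_id == local_node_id and sample_info.sum_tokens >= min_tokens\n    return combined_filter\n\n\nif __name__ == '__main__':\n    samples = []\n    for node, tokens in [('node_A', 100), ('node_A', 600), ('node_B', 700)]:\n        info = SampleInfo(agent_group=0, sum_tokens=tokens, prompt_length=tokens, response_length=0, uid=uuid.uuid4().int)\n        info.node_id = node  # set manually, as the unit tests do\n        samples.append(info)\n\n    keep_local_long = make_affinity_and_length_filter('node_A', min_tokens=512)\n    print([s.sum_tokens for s in samples if keep_local_long(s)])  # expected: [600]\n"
  },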
  {
    "path": "tests/scheduler/test_process_group_manager.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport unittest\nimport collections\nfrom typing import List, Optional, Set\nfrom siirl.scheduler import ProcessGroupManager\nfrom siirl.workers.dag import NodeType, Node, NodeRole, TaskGraph\n\n\nclass TestProcessGroupManager(unittest.TestCase):\n    def _create_task_graph(self, graph_id: str, node_specs: List[tuple[str, NodeType, NodeRole, Optional[List[str]]]]) -> TaskGraph:\n        \"\"\"\n        Helper to create TaskGraph with actual Node objects, including dependencies.\n        node_specs: List of (node_id, node_type, node_role, list_of_dependency_ids)\n        \"\"\"\n        tg = TaskGraph(graph_id=graph_id)  #\n        nodes_to_add = []\n        for node_id_spec, node_type_spec, node_role_spec, deps_spec in node_specs:\n            # Node class validation for NodeRole based on NodeType\n            # For simplicity, ensure roles are compatible or use NodeRole.DEFAULT for non-model types.\n            current_role = node_role_spec\n            if node_type_spec not in [NodeType.COMPUTE, NodeType.MODEL_TRAIN, NodeType.MODEL_INFERENCE] and current_role != NodeRole.DEFAULT:  #\n                current_role = NodeRole.DEFAULT  #\n\n            nodes_to_add.append(Node(node_id=node_id_spec, node_type=node_type_spec, node_role=current_role, dependencies=deps_spec))  #\n        if nodes_to_add:\n            tg.add_nodes(nodes_to_add)  #\n        # PGM doesn't rely on built adjacency lists, but it's good practice for a valid TG\n        tg.build_adjacency_lists()  #\n        return tg\n\n    def test_initialization_default_relevant_types(self):\n        pgm = ProcessGroupManager(total_num_workers=1, ranks_taskgraph_mapping={})  #\n        self.assertEqual(pgm.relevant_node_types, {NodeType.MODEL_INFERENCE, NodeType.MODEL_TRAIN})  #\n\n    def test_initialization_custom_relevant_types(self):\n        custom_types = {NodeType.COMPUTE, NodeType.DATA_LOAD}  #\n        pgm = ProcessGroupManager(total_num_workers=1, ranks_taskgraph_mapping={}, relevant_node_types_param=custom_types)  #\n        self.assertEqual(pgm.relevant_node_types, custom_types)  #\n\n    def test_initialization_invalid_custom_relevant_types(self):\n        with self.assertRaises(ValueError):  #\n            ProcessGroupManager(total_num_workers=1, ranks_taskgraph_mapping={}, relevant_node_types_param={\"not_a_set\"})  # type: ignore\n        with self.assertRaises(ValueError):  #\n            ProcessGroupManager(total_num_workers=1, ranks_taskgraph_mapping={}, relevant_node_types_param={NodeType.COMPUTE, \"not_a_node_type\"})  # type: ignore\n\n    def test_default_filtering_model_nodes_only(self):\n        \"\"\"Tests that by default, only MODEL_INFERENCE and MODEL_TRAIN nodes are processed.\"\"\"\n        # N_train depends on N_compute, but N_compute is ignored.\n        # N_infer also depends on N_compute.\n        tg1 = self._create_task_graph(\n            \"TG1\",\n           
 [\n                (\"N_compute\", NodeType.COMPUTE, NodeRole.DEFAULT, None),  # Irrelevant by default\n                (\"N_train\", NodeType.MODEL_TRAIN, NodeRole.DEFAULT, [\"N_compute\"]),  # Relevant\n                (\"N_infer\", NodeType.MODEL_INFERENCE, NodeRole.DEFAULT, [\"N_compute\"]),  # Relevant\n            ],\n        )\n        mapping = {0: tg1, 1: tg1}  # N_train, N_infer from TG1 on ranks 0, 1\n        pgm = ProcessGroupManager(total_num_workers=2, ranks_taskgraph_mapping=mapping)  #\n\n        # N_compute is ignored, so its dependencies don't affect PGM's view of N_train/N_infer.\n        self.assertEqual(pgm.node_ranks_mapping, {\"N_train\": [0, 1], \"N_infer\": [0, 1]})  #\n        self.assertNotIn(\"N_compute\", pgm.node_ranks_mapping)  #\n        self.assertEqual(pgm.process_group_spec, {\"process_group_1\": [0, 1]})  #\n        self.assertEqual(\n            pgm.node_process_group_mapping,  #\n            {\"N_train\": \"process_group_1\", \"N_infer\": \"process_group_1\"},\n        )\n\n        expected_node_type_pg_map = {  #\n            NodeType.MODEL_TRAIN.value: {\"process_group_1\"},  #\n            NodeType.MODEL_INFERENCE.value: {\"process_group_1\"},  #\n        }\n        self.assertEqual(pgm.node_type_process_group_mapping, expected_node_type_pg_map)  #\n        self.assertNotIn(NodeType.COMPUTE.value, pgm.node_type_process_group_mapping)  #\n\n        self.assertEqual(pgm.subgraph_node_type_pg_mapping[\"TG1\"][NodeType.MODEL_TRAIN.value], {\"process_group_1\"})  #\n        self.assertEqual(pgm.subgraph_node_type_pg_mapping[\"TG1\"][NodeType.MODEL_INFERENCE.value], {\"process_group_1\"})  #\n        self.assertNotIn(NodeType.COMPUTE.value, pgm.subgraph_node_type_pg_mapping[\"TG1\"])  #\n\n        # Test getters\n        self.assertEqual(pgm.get_process_groups_for_node_type(NodeType.MODEL_TRAIN.value), {\"process_group_1\"})  #\n        self.assertEqual(pgm.get_process_groups_for_node_type(NodeType.COMPUTE.value), set())  # Not relevant\n        self.assertEqual(pgm.get_process_group_for_node_type_in_subgraph(\"TG1\", NodeType.MODEL_INFERENCE.value), {\"process_group_1\"})  #\n        self.assertEqual(pgm.get_process_group_for_node_type_in_subgraph(\"TG1\", NodeType.COMPUTE.value), set())  #\n\n    def test_custom_filtering_compute_and_load_only(self):\n        \"\"\"Tests filtering with custom relevant types (e.g., COMPUTE and DATA_LOAD).\"\"\"\n        relevant_types = {NodeType.COMPUTE, NodeType.DATA_LOAD}  #\n        # N_compute depends on N_load. 
N_train depends on N_compute but is ignored.\n        tg1 = self._create_task_graph(\n            \"TG1\",\n            [\n                (\"N_load\", NodeType.DATA_LOAD, NodeRole.DEFAULT, None),  # Relevant\n                (\"N_compute\", NodeType.COMPUTE, NodeRole.DEFAULT, [\"N_load\"]),  # Relevant\n                (\"N_train\", NodeType.MODEL_TRAIN, NodeRole.DEFAULT, [\"N_compute\"]),  # Ignored\n            ],\n        )\n        mapping = {0: tg1, 1: tg1}  # N_load, N_compute from TG1 on ranks 0, 1\n        pgm = ProcessGroupManager(total_num_workers=2, ranks_taskgraph_mapping=mapping, relevant_node_types_param=relevant_types)  #\n\n        self.assertEqual(pgm.node_ranks_mapping, {\"N_load\": [0, 1], \"N_compute\": [0, 1]})  #\n        self.assertNotIn(\"N_train\", pgm.node_ranks_mapping)  #\n        self.assertEqual(pgm.process_group_spec, {\"process_group_1\": [0, 1]})  #\n        self.assertEqual(\n            pgm.node_process_group_mapping,  #\n            {\"N_load\": \"process_group_1\", \"N_compute\": \"process_group_1\"},\n        )\n\n        expected_node_type_pg_map = {  #\n            NodeType.DATA_LOAD.value: {\"process_group_1\"},  #\n            NodeType.COMPUTE.value: {\"process_group_1\"},  #\n        }\n        self.assertEqual(pgm.node_type_process_group_mapping, expected_node_type_pg_map)  #\n        self.assertEqual(pgm.subgraph_node_type_pg_mapping[\"TG1\"], expected_node_type_pg_map)  #\n\n        # Test getters\n        self.assertEqual(pgm.get_process_groups_for_node_type(NodeType.COMPUTE.value), {\"process_group_1\"})  #\n        self.assertEqual(pgm.get_process_groups_for_node_type(NodeType.MODEL_TRAIN.value), set())  # Not relevant\n        self.assertEqual(pgm.get_process_group_for_node_type_in_subgraph(\"TG1\", NodeType.DATA_LOAD.value), {\"process_group_1\"})  #\n        self.assertEqual(pgm.get_process_group_for_node_type_in_subgraph(\"TG1\", NodeType.MODEL_TRAIN.value), set())  #\n\n    def test_no_relevant_nodes_in_graph_custom_filter(self):\n        \"\"\"Tests behavior when a graph contains no nodes of the custom configured relevant types.\"\"\"\n        relevant_types = {NodeType.CUSTOM}  # Only interested in CUSTOM\n        tg1 = self._create_task_graph(\n            \"TG1\",\n            [\n                (\"N_train\", NodeType.MODEL_TRAIN, NodeRole.DEFAULT, None),  #\n                (\"N_compute\", NodeType.COMPUTE, NodeRole.DEFAULT, None),  #\n            ],\n        )  # No CUSTOM nodes\n        mapping = {0: tg1}\n        pgm = ProcessGroupManager(total_num_workers=1, ranks_taskgraph_mapping=mapping, relevant_node_types_param=relevant_types)  #\n\n        self.assertEqual(pgm.node_ranks_mapping, {})  #\n        self.assertEqual(pgm.process_group_spec, {})  #\n        self.assertEqual(pgm.node_process_group_mapping, {})  #\n        self.assertEqual(pgm.node_type_process_group_mapping, collections.defaultdict(set))  #\n        self.assertEqual(pgm.subgraph_node_type_pg_mapping, collections.defaultdict(lambda: collections.defaultdict(set)))  #\n\n    def test_empty_relevant_types_set_no_nodes_processed(self):\n        \"\"\"Tests behavior with an empty set of relevant_node_types (no nodes should be processed).\"\"\"\n        relevant_types: Set[NodeType] = set()  # Empty set\n        tg1 = self._create_task_graph(\n            \"TG1\",\n            [\n                (\"N_train\", NodeType.MODEL_TRAIN, NodeRole.DEFAULT, None),  #\n                (\"N_compute\", NodeType.COMPUTE, NodeRole.DEFAULT, None),  #\n            ],\n        )\n     
   mapping = {0: tg1}\n        pgm = ProcessGroupManager(total_num_workers=1, ranks_taskgraph_mapping=mapping, relevant_node_types_param=relevant_types)  #\n\n        self.assertEqual(pgm.node_ranks_mapping, {})  #\n        self.assertEqual(pgm.process_group_spec, {})  #\n        self.assertEqual(pgm.node_type_process_group_mapping, collections.defaultdict(set))  #\n        self.assertEqual(pgm.subgraph_node_type_pg_mapping, collections.defaultdict(lambda: collections.defaultdict(set)))  #\n\n    def test_complex_scenario_with_dependencies_and_default_filtering(self):\n        \"\"\"A more complex setup with dependencies using default filtering.\"\"\"\n        # TG1: M1_train depends on C1_comp (ignored). M1_infer also depends on C1_comp (ignored).\n        #      So M1_train and M1_infer are effectively roots for PGM.\n        tg1 = self._create_task_graph(\n            \"TG1_MODELS\",\n            [\n                (\"C1_comp\", NodeType.COMPUTE, NodeRole.DEFAULT, None),  # Ignored\n                (\"M1_train\", NodeType.MODEL_TRAIN, NodeRole.DEFAULT, [\"C1_comp\"]),  # Relevant\n                (\"M1_infer\", NodeType.MODEL_INFERENCE, NodeRole.DEFAULT, [\"C1_comp\"]),  # Relevant\n            ],\n        )\n        # TG2: M2_train depends on D2_load (ignored). So M2_train is effectively a root.\n        tg2 = self._create_task_graph(\n            \"TG2_MODELS\",\n            [\n                (\"D2_load\", NodeType.DATA_LOAD, NodeRole.DEFAULT, None),  # Ignored\n                (\"M2_train\", NodeType.MODEL_TRAIN, NodeRole.DEFAULT, [\"D2_load\"]),  # Relevant\n            ],\n        )\n        # TG3: All nodes ignored. C3_comp depends on D3_load.\n        tg3 = self._create_task_graph(\n            \"TG3_OTHER\",\n            [\n                (\"D3_load\", NodeType.DATA_LOAD, NodeRole.DEFAULT, None),  # Ignored\n                (\"C3_comp\", NodeType.COMPUTE, NodeRole.DEFAULT, [\"D3_load\"]),  # Ignored\n            ],\n        )\n        # TG4: Relevant M4_train depends on irrelevant C4_comp which depends on relevant M4_infer.\n        # PGM will see M4_train and M4_infer.\n        tg4 = self._create_task_graph(\n            \"TG4_MIXED_DEP\",\n            [\n                (\"M4_infer\", NodeType.MODEL_INFERENCE, NodeRole.DEFAULT, None),  # Relevant\n                (\"C4_comp\", NodeType.COMPUTE, NodeRole.DEFAULT, [\"M4_infer\"]),  # Ignored\n                (\"M4_train\", NodeType.MODEL_TRAIN, NodeRole.DEFAULT, [\"C4_comp\"]),  # Relevant\n            ],\n        )\n\n        mapping = {\n            0: tg1,\n            1: tg1,  # TG1: M1_train, M1_infer on ranks [0,1] -> PG1\n            2: tg2,\n            3: tg2,  # TG2: M2_train on ranks [2,3] -> PG2\n            4: tg3,  # TG3: No relevant nodes.\n            5: tg4,\n            6: tg4,  # TG4: M4_infer, M4_train on ranks [5,6] -> PG3\n        }\n        pgm = ProcessGroupManager(total_num_workers=7, ranks_taskgraph_mapping=mapping)  #\n\n        self.assertEqual(\n            pgm.node_ranks_mapping,\n            {  #\n                \"M1_train\": [0, 1],\n                \"M1_infer\": [0, 1],\n                \"M2_train\": [2, 3],\n                \"M4_infer\": [5, 6],\n                \"M4_train\": [5, 6],\n            },\n        )\n        # Ensure ignored nodes are not present\n        for ignored_node_id in [\"C1_comp\", \"D2_load\", \"D3_load\", \"C3_comp\", \"C4_comp\"]:\n            self.assertNotIn(ignored_node_id, pgm.node_ranks_mapping)  #\n\n        # Rank tuples: (0,1), (2,3), (5,6). 
Sorted: (0,1), (2,3), (5,6).\n        self.assertEqual(\n            pgm.process_group_spec,\n            {  #\n                \"process_group_1\": [0, 1],\n                \"process_group_2\": [2, 3],\n                \"process_group_3\": [5, 6],\n            },\n        )\n\n        self.assertEqual(\n            pgm.node_process_group_mapping,\n            {  #\n                \"M1_train\": \"process_group_1\",\n                \"M1_infer\": \"process_group_1\",\n                \"M2_train\": \"process_group_2\",\n                \"M4_infer\": \"process_group_3\",\n                \"M4_train\": \"process_group_3\",\n            },\n        )\n\n        expected_node_type_pg_map = {  #\n            NodeType.MODEL_TRAIN.value: {\"process_group_1\", \"process_group_2\", \"process_group_3\"},  #\n            NodeType.MODEL_INFERENCE.value: {\"process_group_1\", \"process_group_3\"},  #\n        }\n        self.assertEqual(pgm.node_type_process_group_mapping, expected_node_type_pg_map)  #\n\n        # Check subgraph mappings\n        self.assertEqual(pgm.subgraph_node_type_pg_mapping[\"TG1_MODELS\"][NodeType.MODEL_TRAIN.value], {\"process_group_1\"})  #\n        self.assertEqual(pgm.subgraph_node_type_pg_mapping[\"TG1_MODELS\"][NodeType.MODEL_INFERENCE.value], {\"process_group_1\"})  #\n        self.assertNotIn(NodeType.COMPUTE.value, pgm.subgraph_node_type_pg_mapping[\"TG1_MODELS\"])  #\n\n        self.assertEqual(pgm.subgraph_node_type_pg_mapping[\"TG2_MODELS\"][NodeType.MODEL_TRAIN.value], {\"process_group_2\"})  #\n\n        self.assertEqual(pgm.subgraph_node_type_pg_mapping[\"TG4_MIXED_DEP\"][NodeType.MODEL_INFERENCE.value], {\"process_group_3\"})  #\n        self.assertEqual(pgm.subgraph_node_type_pg_mapping[\"TG4_MIXED_DEP\"][NodeType.MODEL_TRAIN.value], {\"process_group_3\"})  #\n\n        self.assertNotIn(\"TG3_OTHER\", pgm.subgraph_node_type_pg_mapping)  # No relevant nodes\n\n    def test_highly_complex_scenario_shared_nodes_16_workers(self):\n        \"\"\"\n        Test: A more complex scenario with 16 workers, multiple TaskGraphs,\n        nodes with shared IDs across these TaskGraphs, and intricate rank overlaps.\n        Uses default relevant types (MODEL_TRAIN, MODEL_INFERENCE).\n        \"\"\"\n        # Define Node Specs: (id, type, role, dependencies)\n\n        # --- Shared Nodes ---\n        # S_Train_1 will appear in TG_Alpha and TG_Beta\n        s_train_1_spec = (\"S_Train_1\", NodeType.MODEL_TRAIN, NodeRole.DEFAULT, None)\n        # S_Infer_1 will appear in TG_Beta and TG_Gamma\n        s_infer_1_spec = (\"S_Infer_1\", NodeType.MODEL_INFERENCE, NodeRole.DEFAULT, None)\n        # S_Train_2 will appear in TG_Gamma and TG_Delta\n        s_train_2_spec = (\"S_Train_2\", NodeType.MODEL_TRAIN, NodeRole.DEFAULT, None)\n\n        # --- TaskGraph Alpha Specific Nodes ---\n        tga_mt_1_spec = (\"TGA_MT_1\", NodeType.MODEL_TRAIN, NodeRole.DEFAULT, [\"S_Train_1\"])\n        tga_c1_spec = (\"TGA_C1\", NodeType.COMPUTE, NodeRole.DEFAULT, [\"TGA_MT_1\"])  # Irrelevant\n\n        # --- TaskGraph Beta Specific Nodes ---\n        tgb_mi_1_spec = (\"TGB_MI_1\", NodeType.MODEL_INFERENCE, NodeRole.DEFAULT, [\"S_Train_1\"])\n        tgb_c1_spec = (\"TGB_C1\", NodeType.COMPUTE, NodeRole.DEFAULT, [\"S_Infer_1\"])  # Irrelevant\n\n        # --- TaskGraph Gamma Specific Nodes ---\n        tgc_mt_1_spec = (\"TGC_MT_1\", NodeType.MODEL_TRAIN, NodeRole.DEFAULT, [\"S_Infer_1\"])\n        tgc_dl1_spec = (\"TGC_DL1\", NodeType.DATA_LOAD, NodeRole.DEFAULT, [\"S_Train_2\"])  # Irrelevant\n\n    
    # --- TaskGraph Delta Specific Nodes ---\n        tgd_mi_1_spec = (\"TGD_MI_1\", NodeType.MODEL_INFERENCE, NodeRole.DEFAULT, [\"S_Train_2\"])\n        tgd_c1_spec = (\"TGD_C1\", NodeType.COMPUTE, NodeRole.DEFAULT, [\"TGD_MI_1\"])  # Irrelevant\n\n        # --- TaskGraph Epsilon Nodes (all irrelevant) ---\n        tge_c1_spec = (\"TGE_C1\", NodeType.COMPUTE, NodeRole.DEFAULT, None)\n        tge_dl1_spec = (\"TGE_DL1\", NodeType.DATA_LOAD, NodeRole.DEFAULT, [\"TGE_C1\"])\n\n        # --- TaskGraph Zeta Node (isolated relevant node) ---\n        tgz_mt_1_spec = (\"TGZ_MT_1\", NodeType.MODEL_TRAIN, NodeRole.DEFAULT, None)\n\n        # Create TaskGraph instances\n        tg_alpha = self._create_task_graph(\"TG_Alpha\", [s_train_1_spec, tga_mt_1_spec, tga_c1_spec])\n        tg_beta = self._create_task_graph(\"TG_Beta\", [s_train_1_spec, s_infer_1_spec, tgb_mi_1_spec, tgb_c1_spec])\n        tg_gamma = self._create_task_graph(\"TG_Gamma\", [s_infer_1_spec, s_train_2_spec, tgc_mt_1_spec, tgc_dl1_spec])\n        tg_delta = self._create_task_graph(\"TG_Delta\", [s_train_2_spec, tgd_mi_1_spec, tgd_c1_spec])\n        tg_epsilon = self._create_task_graph(\"TG_Epsilon\", [tge_c1_spec, tge_dl1_spec])\n        tg_zeta = self._create_task_graph(\"TG_Zeta\", [tgz_mt_1_spec])\n\n        # Define rank assignments for each TaskGraph for PGM input\n        # PGM's `_collect_initial_topology_info` iterates `ranks_taskgraph_mapping.items()`.\n        # It builds `graph_id_to_ranks`. For a node_id present in multiple GIDs,\n        # `_aggregate_ranks_for_nodes` takes the union of `graph_id_to_ranks` for those GIDs.\n        final_mapping_for_pgm = {}\n        for r_val in range(16):\n            final_mapping_for_pgm[r_val] = None  # Initialize\n\n        # Assign ranks to TaskGraphs. 
If a rank is assigned multiple TGs, the last one wins for that rank.\n        # PGM's graph_id_to_ranks will reflect the set of ranks ultimately pointing to each GID.\n        # Ranks for TG_Alpha\n        for r_idx in {0, 1}:\n            final_mapping_for_pgm[r_idx] = tg_alpha\n        # Ranks for TG_Beta (rank 2 will be tg_beta, rank 4 tg_gamma, rank 6 tg_delta)\n        for r_idx in {2, 3, 10}:\n            final_mapping_for_pgm[r_idx] = tg_beta\n        # Ranks for TG_Gamma\n        for r_idx in {4, 5, 11}:\n            final_mapping_for_pgm[r_idx] = tg_gamma\n        # Ranks for TG_Delta\n        for r_idx in {6, 7, 12, 13}:\n            final_mapping_for_pgm[r_idx] = tg_delta\n        # Ranks for TG_Epsilon (no relevant nodes)\n        for r_idx in {8, 9}:\n            final_mapping_for_pgm[r_idx] = tg_epsilon\n        # Ranks for TG_Zeta\n        for r_idx in {14, 15}:\n            final_mapping_for_pgm[r_idx] = tg_zeta\n\n        # Update tg_alpha's mapping for rank 2 if it wasn't overwritten\n        # To achieve the intended rank aggregation, the final_mapping_for_pgm should be:\n        # TG_Alpha assigned to {0,1}\n        # TG_Beta assigned to {2,3,10}\n        # TG_Gamma assigned to {4,5,11}\n        # TG_Delta assigned to {6,7,12,13}\n        # TG_Epsilon assigned to {8,9}\n        # TG_Zeta assigned to {14,15}\n        #\n        # And Node_Shared_Train is in TG_Alpha (ID: S_Train_1) and TG_Beta (ID: S_Train_1)\n        # Node_Shared_Infer is in TG_Beta (ID: S_Infer_1) and TG_Gamma (ID: S_Infer_1)\n        # Node_Shared_Train_2 is in TG_Gamma (ID: S_Train_2) and TG_Delta (ID: S_Train_2)\n        # This setup correctly tests the union of ranks for shared node_ids.\n\n        pgm = ProcessGroupManager(total_num_workers=16, ranks_taskgraph_mapping=final_mapping_for_pgm)\n\n        # Expected Node Ranks based on `final_mapping_for_pgm` and node presence in TGs:\n        # S_Train_1: in TG_Alpha (ranks {0,1}), in TG_Beta (ranks {2,3,10}) => {0,1,2,3,10}\n        # TGA_MT_1: in TG_Alpha (ranks {0,1}) => {0,1}\n        # S_Infer_1: in TG_Beta (ranks {2,3,10}), in TG_Gamma (ranks {4,5,11}) => {2,3,4,5,10,11}\n        # TGB_MI_1: in TG_Beta (ranks {2,3,10}) => {2,3,10}\n        # S_Train_2: in TG_Gamma (ranks {4,5,11}), in TG_Delta (ranks {6,7,12,13}) => {4,5,6,7,11,12,13}\n        # TGC_MT_1: in TG_Gamma (ranks {4,5,11}) => {4,5,11}\n        # TGD_MI_1: in TG_Delta (ranks {6,7,12,13}) => {6,7,12,13}\n        # TGZ_MT_1: in TG_Zeta (on {14,15}) => {14,15}\n        self.assertEqual(\n            pgm.node_ranks_mapping,\n            {\"S_Train_1\": sorted([0, 1, 2, 3, 10]), \"TGA_MT_1\": sorted([0, 1]), \"S_Infer_1\": sorted([2, 3, 4, 5, 10, 11]), \"TGB_MI_1\": sorted([2, 3, 10]), \"S_Train_2\": sorted([4, 5, 6, 7, 11, 12, 13]), \"TGC_MT_1\": sorted([4, 5, 11]), \"TGD_MI_1\": sorted([6, 7, 12, 13]), \"TGZ_MT_1\": sorted([14, 15])},\n        )\n\n        # Corrected expected process_group_spec based on lexicographical sort of unique rank tuples\n        # Unique rank tuples:\n        # (0,1)\n        # (0,1,2,3,10)\n        # (2,3,4,5,10,11)\n        # (2,3,10)\n        # (4,5,6,7,11,12,13)\n        # (4,5,11)\n        # (6,7,12,13)\n        # (14,15)\n        # Sorted list of these tuples gives PG names:\n        # PG1: (0,1)\n        # PG2: (0,1,2,3,10)\n        # PG3: (2,3,4,5,10,11)  <-- Corrected from previous manual sort error\n        # PG4: (2,3,10)         <-- Corrected\n        # PG5: (4,5,6,7,11,12,13) <-- Corrected\n        # PG6: (4,5,11)         <-- Corrected\n        # 
PG7: (6,7,12,13)\n        # PG8: (14,15)\n        self.assertEqual(\n            pgm.process_group_spec,\n            {\n                \"process_group_1\": sorted([0, 1]),\n                \"process_group_2\": sorted([0, 1, 2, 3, 10]),\n                \"process_group_3\": sorted([2, 3, 4, 5, 10, 11]),  # Corresponds to S_Infer_1\n                \"process_group_4\": sorted([2, 3, 10]),  # Corresponds to TGB_MI_1\n                \"process_group_5\": sorted([4, 5, 6, 7, 11, 12, 13]),  # Corresponds to S_Train_2\n                \"process_group_6\": sorted([4, 5, 11]),  # Corresponds to TGC_MT_1\n                \"process_group_7\": sorted([6, 7, 12, 13]),  # Corresponds to TGD_MI_1\n                \"process_group_8\": sorted([14, 15]),  # Corresponds to TGZ_MT_1\n            },\n        )\n\n        # Corrected expected node_process_group_mapping\n        self.assertEqual(\n            pgm.node_process_group_mapping,\n            {\n                \"TGA_MT_1\": \"process_group_1\",  # Ranks (0,1)\n                \"S_Train_1\": \"process_group_2\",  # Ranks (0,1,2,3,10)\n                \"S_Infer_1\": \"process_group_3\",  # Ranks (2,3,4,5,10,11)\n                \"TGB_MI_1\": \"process_group_4\",  # Ranks (2,3,10)\n                \"S_Train_2\": \"process_group_5\",  # Ranks (4,5,6,7,11,12,13)\n                \"TGC_MT_1\": \"process_group_6\",  # Ranks (4,5,11)\n                \"TGD_MI_1\": \"process_group_7\",  # Ranks (6,7,12,13)\n                \"TGZ_MT_1\": \"process_group_8\",  # Ranks (14,15)\n            },\n        )\n\n        # Corrected expected node_type_process_group_mapping\n        # MODEL_TRAIN: TGA_MT_1(PG1), S_Train_1(PG2), S_Train_2(PG5), TGC_MT_1(PG6), TGZ_MT_1(PG8)\n        # MODEL_INFERENCE: S_Infer_1(PG3), TGB_MI_1(PG4), TGD_MI_1(PG7)\n        expected_type_map = {NodeType.MODEL_TRAIN.value: {\"process_group_1\", \"process_group_2\", \"process_group_5\", \"process_group_6\", \"process_group_8\"}, NodeType.MODEL_INFERENCE.value: {\"process_group_3\", \"process_group_4\", \"process_group_7\"}}\n        self.assertEqual(pgm.node_type_process_group_mapping, expected_type_map)\n\n        # Corrected expected subgraph_node_type_pg_mapping\n        self.assertEqual(\n            pgm.subgraph_node_type_pg_mapping[\"TG_Alpha\"],\n            {\n                NodeType.MODEL_TRAIN.value: {\"process_group_1\", \"process_group_2\"}  # TGA_MT_1 (PG1), S_Train_1 (PG2)\n            },\n        )\n        self.assertEqual(\n            pgm.subgraph_node_type_pg_mapping[\"TG_Beta\"],\n            {\n                NodeType.MODEL_TRAIN.value: {\"process_group_2\"},  # S_Train_1 (PG2)\n                NodeType.MODEL_INFERENCE.value: {\"process_group_3\", \"process_group_4\"},  # S_Infer_1 (PG3), TGB_MI_1 (PG4)\n            },\n        )\n        self.assertEqual(\n            pgm.subgraph_node_type_pg_mapping[\"TG_Gamma\"],\n            {\n                NodeType.MODEL_INFERENCE.value: {\"process_group_3\"},  # S_Infer_1 (PG3)\n                NodeType.MODEL_TRAIN.value: {\"process_group_5\", \"process_group_6\"},  # S_Train_2 (PG5), TGC_MT_1 (PG6)\n            },\n        )\n        self.assertEqual(\n            pgm.subgraph_node_type_pg_mapping[\"TG_Delta\"],\n            {\n                NodeType.MODEL_TRAIN.value: {\"process_group_5\"},  # S_Train_2 (PG5)\n                NodeType.MODEL_INFERENCE.value: {\"process_group_7\"},  # TGD_MI_1 (PG7)\n            },\n        )\n        self.assertNotIn(\"TG_Epsilon\", pgm.subgraph_node_type_pg_mapping)\n        
self.assertEqual(\n            pgm.subgraph_node_type_pg_mapping[\"TG_Zeta\"],\n            {\n                NodeType.MODEL_TRAIN.value: {\"process_group_8\"}  # TGZ_MT_1 (PG8)\n            },\n        )\n\n        # Verify getters for some specific cases\n        self.assertEqual(pgm.get_node_assignment(\"S_Train_1\"), {\"ranks\": sorted([0, 1, 2, 3, 10]), \"process_group_name\": \"process_group_2\"})\n        self.assertEqual(pgm.get_process_groups_for_node_type(NodeType.MODEL_TRAIN.value), {\"process_group_1\", \"process_group_2\", \"process_group_5\", \"process_group_6\", \"process_group_8\"})\n        self.assertEqual(pgm.get_process_group_for_node_type_in_subgraph(\"TG_Beta\", NodeType.MODEL_INFERENCE.value), {\"process_group_3\", \"process_group_4\"})\n        self.assertEqual(pgm.get_process_group_for_node_type_in_subgraph(\"TG_Epsilon\", NodeType.COMPUTE.value), set())\n\n\nif __name__ == \"__main__\":\n    unittest.main(argv=[\"first-arg-is-ignored\"], exit=False, verbosity=2)\n"
  },
  {
    "path": "tests/scheduler/test_task_scheduler.py",
    "content": "# Copyright 2025, Shanghai Innovation Institute.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# test_task_scheduler.py\nimport unittest\nimport collections\nfrom typing import List, Dict, Optional\nfrom siirl.workers.dag.node import Node, NodeType\nfrom siirl.workers.dag.task_graph import TaskGraph\nfrom siirl.scheduler.task_scheduler import TaskScheduler\n\n\n# Helper function to create dummy TaskGraph objects for testing\ndef create_test_graph(graph_id: str, num_nodes: int, model_params: any = 0.0, dependencies_map: Optional[Dict[int, List[int]]] = None) -> TaskGraph:\n    \"\"\"\n    Creates a TaskGraph for testing purposes.\n    Args:\n        graph_id (str): The ID for the graph.\n        num_nodes (int): Number of nodes to create in the graph.\n        model_params (any): Model parameters for the first node (if applicable). Can be float or string like \"10B\".\n        dependencies_map (Optional[Dict[int, List[int]]]): A map where key is node index and value is a list of dependency indices.\n                                                            Example: {1: [0]} means node 1 depends on node 0.\n    Returns:\n        TaskGraph: The created TaskGraph object.\n    \"\"\"\n    graph = TaskGraph(graph_id=graph_id)\n    nodes_to_add = []\n    for i in range(num_nodes):\n        node_type_val = NodeType.COMPUTE\n        node_config_val = {}\n        current_node_id = f\"{graph_id}_n{i}\"\n        node_deps_ids: List[str] = []\n\n        if dependencies_map and i in dependencies_map:\n            for dep_idx in dependencies_map[i]:\n                if 0 <= dep_idx < i:  # Ensure dependencies are on already defined (or to be defined earlier) nodes\n                    node_deps_ids.append(f\"{graph_id}_n{dep_idx}\")\n                else:\n                    # This case should ideally be handled by graph validation, but good to be aware\n                    pass  # Or raise an error for invalid dependency definition in test setup\n\n        # Assign model_params to the first node if provided, to make it a MODEL_TRAIN type\n        if model_params != 0.0 and i == 0:\n            node_type_val = NodeType.MODEL_TRAIN\n            node_config_val = {\"model_params\": model_params}\n\n        nodes_to_add.append(Node(node_id=current_node_id, node_type=node_type_val, config=node_config_val, dependencies=node_deps_ids))\n\n    if nodes_to_add:\n        graph.add_nodes(nodes_to_add)\n        graph.build_adjacency_lists()  # Crucial for graph operations and validation\n        is_valid, msg = graph.validate_graph()\n        if not is_valid:\n            # This helps catch issues in test graph creation itself\n            raise ValueError(f\"Test helper created an invalid graph '{graph_id}': {msg}\")\n    return graph\n\n\nclass TestTaskScheduler(unittest.TestCase):\n    \"\"\"\n    Unit tests for the TaskScheduler class.\n    \"\"\"\n\n    def setUp(self):\n        \"\"\"\n        Set up common resources for tests.\n        This method is called 
before each test function.\n        \"\"\"\n        # Default scheduler configuration for many tests\n        self.num_nodes_default = 2\n        self.gpus_per_node_default = 2\n        self.scheduler_default = TaskScheduler(num_physical_nodes=self.num_nodes_default, gpus_per_node=self.gpus_per_node_default)\n        # Total workers = 2 * 2 = 4\n\n    def test_scheduler_initialization(self):\n        \"\"\"\n        Test the initialization of the TaskScheduler.\n        \"\"\"\n        scheduler = TaskScheduler(num_physical_nodes=3, gpus_per_node=4)\n        self.assertEqual(scheduler.num_physical_nodes, 3)\n        self.assertEqual(scheduler.gpus_per_node, 4)\n        self.assertEqual(scheduler.num_workers, 12)  # 3 nodes * 4 GPUs/node\n        self.assertEqual(len(scheduler.worker_to_graph_assignment), 12)\n        self.assertTrue(all(assignment is None for assignment in scheduler.worker_to_graph_assignment.values()))\n\n        # Test initialization with invalid parameters\n        with self.assertRaises(ValueError):\n            TaskScheduler(num_physical_nodes=0, gpus_per_node=2)\n        with self.assertRaises(ValueError):\n            TaskScheduler(num_physical_nodes=2, gpus_per_node=0)\n\n    def test_reset_scheduler_state(self):\n        \"\"\"\n        Test the _reset_scheduler_state method.\n        \"\"\"\n        scheduler = self.scheduler_default  # Uses 2 nodes, 2 GPUs/node = 4 workers\n\n        # Simulate some assignments\n        graph1 = create_test_graph(\"g1\", 1)\n        scheduler.worker_to_graph_assignment[0] = graph1\n        scheduler.node_active_worker_count[0] = 1\n        scheduler.node_free_gpus[0] = [1]  # Worker 0 on node 0 is busy, worker 1 is free\n        scheduler.node_free_gpus[1] = [2, 3]\n\n        scheduler._reset_scheduler_state()  # Call the reset method\n\n        # Check if state is reset to initial conditions\n        self.assertEqual(len(scheduler.worker_to_graph_assignment), scheduler.num_workers)\n        self.assertTrue(all(assignment is None for assignment in scheduler.worker_to_graph_assignment.values()))\n        self.assertEqual(scheduler.node_active_worker_count, collections.defaultdict(int))\n\n        # Verify node_free_gpus is correctly re-initialized\n        expected_free_gpus = {}\n        for i in range(scheduler.num_physical_nodes):\n            expected_free_gpus[i] = list(range(i * scheduler.gpus_per_node, (i + 1) * scheduler.gpus_per_node))\n\n        self.assertEqual(dict(scheduler.node_free_gpus), expected_free_gpus)\n\n    def test_get_original_graph_id(self):\n        \"\"\"\n        Test the _get_original_graph_id helper method.\n        \"\"\"\n        graph_orig = TaskGraph(graph_id=\"original_task\")\n        graph_split1 = TaskGraph(graph_id=\"original_task_final_1\")\n        graph_split2 = TaskGraph(graph_id=\"original_task_reconv1_s0_final_2\")\n        graph_no_suffix = TaskGraph(graph_id=\"another_task\")\n\n        self.assertEqual(self.scheduler_default._get_original_graph_id(graph_orig), \"original_task\")\n        self.assertEqual(self.scheduler_default._get_original_graph_id(graph_split1), \"original_task\")\n        # Assuming the current logic splits by \"_final_\" first\n        self.assertEqual(self.scheduler_default._get_original_graph_id(graph_split2), \"original_task_reconv1_s0\")\n        self.assertEqual(self.scheduler_default._get_original_graph_id(graph_no_suffix), \"another_task\")\n\n    def test_apportion_workers_to_tasks_even_strategy(self):\n        \"\"\"\n        Test 
_apportion_workers_to_tasks with 'even' strategy.\n        \"\"\"\n        scheduler = TaskScheduler(num_physical_nodes=1, gpus_per_node=10)  # 10 workers\n\n        # Scenario 1: Fewer tasks than workers\n        tasks_info1 = [(create_test_graph(\"g1\", 1, model_params=100), 100.0), (create_test_graph(\"g2\", 1, model_params=50), 50.0)]  # 2 tasks\n        apportionment1 = scheduler._apportion_workers_to_tasks(tasks_info1, 10, \"even\")\n        # Expected: 10 workers / 2 tasks = 5 workers per task\n        self.assertEqual(apportionment1.get(\"g1\"), 5)\n        self.assertEqual(apportionment1.get(\"g2\"), 5)\n        self.assertEqual(sum(apportionment1.values()), 10)\n\n        # Scenario 2: More workers, not perfectly divisible\n        tasks_info2 = [(create_test_graph(\"gA\", 1), 10.0), (create_test_graph(\"gB\", 1), 20.0), (create_test_graph(\"gC\", 1), 30.0)]  # 3 tasks\n        apportionment2 = scheduler._apportion_workers_to_tasks(tasks_info2, 10, \"even\")\n        # Expected: 10 workers / 3 tasks. Base = 3. Remainder = 1.\n        # gA, gB, gC (sorted by ID)\n        # gA: 3 + 1 = 4\n        # gB: 3 + 0 = 3 (if gA gets remainder) -> this depends on sort order for remainder\n        # gC: 3 + 0 = 3\n        # Let's check the sum and individual counts based on sorted order of graph_ids.\n        self.assertEqual(sum(apportionment2.values()), 10)\n        counts = collections.Counter(apportionment2.values())\n        self.assertEqual(counts[4], 1)  # One task gets 4\n        self.assertEqual(counts[3], 2)  # Two tasks get 3\n\n        # Scenario 3: Equal tasks and workers\n        tasks_info3 = [(create_test_graph(f\"t{i}\", 1), 10.0) for i in range(5)]  # 5 tasks\n        apportionment3 = scheduler._apportion_workers_to_tasks(tasks_info3, 5, \"even\")\n        self.assertTrue(all(count == 1 for count in apportionment3.values()))\n        self.assertEqual(sum(apportionment3.values()), 5)\n\n        # Scenario 4: No tasks\n        apportionment4 = scheduler._apportion_workers_to_tasks([], 10, \"even\")\n        self.assertEqual(apportionment4, {})\n\n        # Scenario 5: No workers\n        apportionment5 = scheduler._apportion_workers_to_tasks(tasks_info1, 0, \"even\")\n        self.assertTrue(all(count == 0 for count in apportionment5.values()))\n\n    def test_apportion_workers_to_tasks_param_aware_strategy(self):\n        \"\"\"\n        Test _apportion_workers_to_tasks with 'param_aware' strategy.\n        \"\"\"\n        scheduler = TaskScheduler(num_physical_nodes=1, gpus_per_node=10)  # 10 workers\n\n        # Tasks sorted by size for easier verification of param_aware logic\n        task_large = create_test_graph(\"g_large\", 1, model_params=300)\n        task_medium = create_test_graph(\"g_medium\", 1, model_params=200)\n        task_small = create_test_graph(\"g_small\", 1, model_params=100)\n\n        tasks_info_param = [(task_large, 300.0), (task_medium, 200.0), (task_small, 100.0)]  # 3 tasks, total size 600\n\n        # 10 workers for 3 tasks. Each gets 1 initially. 7 remaining.\n        # param_aware will distribute remaining 7 workers one by one, cycling through tasks sorted by size (desc)\n        # g_large (300), g_medium (200), g_small (100)\n        # Initial: g_large:1, g_medium:1, g_small:1 (3 workers used)\n        # Remaining 7:\n        # 1. g_large (+1) -> 2\n        # 2. g_medium (+1) -> 2\n        # 3. g_small (+1) -> 2\n        # 4. g_large (+1) -> 3\n        # 5. g_medium (+1) -> 3\n        # 6. g_small (+1) -> 3\n        # 7. 
g_large (+1) -> 4\n        # Final: g_large: 4, g_medium: 3, g_small: 3\n        apportionment_param = scheduler._apportion_workers_to_tasks(tasks_info_param, 10, \"param_aware\")\n        self.assertEqual(apportionment_param.get(\"g_large\"), 4)\n        self.assertEqual(apportionment_param.get(\"g_medium\"), 3)\n        self.assertEqual(apportionment_param.get(\"g_small\"), 3)\n        self.assertEqual(sum(apportionment_param.values()), 10)\n\n        # Scenario: All tasks have zero size (should fall back to even)\n        task_zero1 = create_test_graph(\"g_zero1\", 1, model_params=0)\n        task_zero2 = create_test_graph(\"g_zero2\", 1, model_params=0)\n        tasks_info_zero = [(task_zero1, 0.0), (task_zero2, 0.0)]  # 2 tasks, 10 workers\n        apportionment_zero = scheduler._apportion_workers_to_tasks(tasks_info_zero, 10, \"param_aware\")\n        # Expected: fallback to 'even', so 5 workers per task\n        self.assertEqual(apportionment_zero.get(\"g_zero1\"), 5)\n        self.assertEqual(apportionment_zero.get(\"g_zero2\"), 5)\n        self.assertEqual(sum(apportionment_zero.values()), 10)\n\n    def test_schedule_no_tasks(self):\n        \"\"\"\n        Test scheduling when no tasks are provided.\n        \"\"\"\n        assignments = self.scheduler_default.schedule_and_assign_tasks([])\n        self.assertTrue(all(graph is None for graph in assignments.values()))\n        self.assertEqual(len(assignments), self.scheduler_default.num_workers)\n\n    def test_schedule_fewer_tasks_than_workers(self):\n        \"\"\"\n        Test scheduling with fewer tasks than available workers.\n        All tasks should be scheduled, and workers apportioned.\n        \"\"\"\n        scheduler = TaskScheduler(num_physical_nodes=2, gpus_per_node=2)  # 4 workers\n        task1 = create_test_graph(\"task1\", 2, model_params=\"100M\")  # Size 100M\n        task2 = create_test_graph(\"task2\", 1, model_params=\"50M\")  # Size 50M\n        original_tasks = [task1, task2]  # 2 tasks\n\n        assignments = scheduler.schedule_and_assign_tasks(\n            original_tasks,\n            apportion_strategy=\"even\",  # 4 workers / 2 tasks = 2 workers per task\n            consider_node_cohesion=True,\n            consider_node_load=True,\n            consider_rank_preference=True,\n        )\n\n        assigned_counts = collections.Counter(g.graph_id for g in assignments.values() if g)\n        # The graph IDs will be suffixed by discover_and_split_parallel_paths\n        # e.g., \"task1_final_1\", \"task2_final_1\" if they are irreducible\n        # For simple graphs, discover_and_split_parallel_paths returns them as is, then renames.\n\n        # We expect two unique graph IDs in assignments, each assigned to 2 workers\n        self.assertEqual(len(assigned_counts), 2)  # Two unique tasks were scheduled\n\n        num_workers_per_task = {}\n        for worker, graph_obj in assignments.items():\n            if graph_obj:\n                num_workers_per_task.setdefault(graph_obj.graph_id, 0)\n                num_workers_per_task[graph_obj.graph_id] += 1\n\n        self.assertTrue(all(count == 2 for count in num_workers_per_task.values()))\n        self.assertEqual(sum(num_workers_per_task.values()), scheduler.num_workers)  # All workers used\n\n    def test_schedule_more_tasks_than_workers_raises_error(self):\n        \"\"\"\n        Test scheduling with more tasks than workers.\n        This should raise a ValueError as per the new requirement.\n        \"\"\"\n        scheduler = 
TaskScheduler(num_physical_nodes=1, gpus_per_node=2)  # 2 workers\n        tasks = [create_test_graph(\"t1\", 1, model_params=\"10M\"), create_test_graph(\"t2\", 1, model_params=\"20M\"), create_test_graph(\"t3\", 1, model_params=\"5M\")]  # 3 tasks\n\n        with self.assertRaisesRegex(ValueError, \"Cannot assign all tasks\"):\n            scheduler.schedule_and_assign_tasks(tasks)\n\n    def test_schedule_with_task_splitting(self):\n        \"\"\"\n        Test scheduling with a task that gets split into multiple irreducible subgraphs.\n        \"\"\"\n        scheduler = TaskScheduler(num_physical_nodes=1, gpus_per_node=4)  # 4 workers\n\n        # Create a graph that will be split (e.g., two parallel paths merging)\n        # A -> B --\\\n        #           C -> D\n        # E -> F --/\n        # discover_and_split_parallel_paths should create two subgraphs:\n        # 1. A -> B -> C -> D\n        # 2. E -> F -> C -> D\n        # (Actual splitting logic is in task_loader, we test its integration)\n        splittable_graph = TaskGraph(graph_id=\"splittable\")\n        splittable_graph.add_nodes(\n            [\n                Node(node_id=\"A\", node_type=NodeType.DATA_LOAD),\n                Node(node_id=\"B\", node_type=NodeType.COMPUTE, dependencies=[\"A\"]),\n                Node(node_id=\"E\", node_type=NodeType.DATA_LOAD),  # Another entry for a parallel path\n                Node(node_id=\"F\", node_type=NodeType.COMPUTE, dependencies=[\"E\"]),\n                Node(node_id=\"C\", node_type=NodeType.COMPUTE, dependencies=[\"B\", \"F\"]),  # Merge point\n                Node(node_id=\"D\", node_type=NodeType.MODEL_TRAIN, dependencies=[\"C\"], config={\"model_params\": \"10B\"}),\n            ]\n        )\n        splittable_graph.build_adjacency_lists()\n        self.assertTrue(splittable_graph.validate_graph()[0], \"Test splittable graph is invalid\")\n\n        assignments = scheduler.schedule_and_assign_tasks(\n            [splittable_graph],\n            apportion_strategy=\"even\",  # 4 workers / 2 subgraphs = 2 workers per subgraph\n            consider_node_cohesion=True,\n        )\n\n        # Count how many workers are assigned to each unique *scheduled* subgraph ID\n        scheduled_subgraph_worker_counts = collections.defaultdict(int)\n        for graph_obj in assignments.values():\n            if graph_obj:\n                scheduled_subgraph_worker_counts[graph_obj.graph_id] += 1\n\n        # We expect two distinct subgraphs to be scheduled (e.g., \"splittable_final_1\", \"splittable_final_2\")\n        self.assertEqual(len(scheduled_subgraph_worker_counts), 2, \"Should schedule two subgraphs after splitting.\")\n        # Each of these two subgraphs should get 2 workers\n        for subgraph_id, count in scheduled_subgraph_worker_counts.items():\n            self.assertEqual(count, 2, f\"Subgraph {subgraph_id} should have 2 workers.\")\n\n        self.assertEqual(sum(scheduled_subgraph_worker_counts.values()), scheduler.num_workers, \"All workers should be utilized.\")\n\n    def test_placement_logic_node_cohesion(self):\n        \"\"\"\n        Test placement logic focusing on node cohesion.\n        If a task gets multiple workers, they should ideally be on the same physical node if cohesion is enabled.\n        \"\"\"\n        # 1 physical node, 4 GPUs. 
So all workers for a task *must* be on node 0.\n        scheduler = TaskScheduler(num_physical_nodes=1, gpus_per_node=4)\n        task1 = create_test_graph(\"task_cohesive\", 1, model_params=\"100M\")\n\n        # Schedule task1, which should get all 4 workers (apportion_strategy='even' or 'param_aware' with one task)\n        assignments = scheduler.schedule_and_assign_tasks(\n            [task1],\n            apportion_strategy=\"even\",\n            consider_node_cohesion=True,\n            consider_node_load=True,  # Load won't matter much with 1 node\n            consider_rank_preference=True,\n        )\n\n        assigned_workers_for_task1 = []\n        for worker_rank, graph_obj in assignments.items():\n            if graph_obj and \"task_cohesive\" in graph_obj.graph_id:  # Check for the scheduled subgraph from task1\n                assigned_workers_for_task1.append(worker_rank)\n\n        self.assertEqual(len(assigned_workers_for_task1), 4, \"Task1 should get 4 workers.\")\n        # All workers (0, 1, 2, 3) are on physical node 0 (worker_rank // gpus_per_node)\n        self.assertTrue(all((rank // scheduler.gpus_per_node) == 0 for rank in assigned_workers_for_task1))\n\n        # More complex: 2 nodes, 2 GPUs/node (4 workers total). 1 task needing 3 workers.\n        scheduler_2n2g = TaskScheduler(num_physical_nodes=2, gpus_per_node=2)\n        task_needs_3 = create_test_graph(\"task_3w\", 1, model_params=\"200M\")\n        # With 4 workers and 1 task, task_3w gets 4 workers. Let's adjust.\n        # We need to control apportionment. Create 2 tasks, one big, one small.\n        task_big = create_test_graph(\"task_big\", 1, model_params=\"300M\")  # Should get more workers\n        task_small = create_test_graph(\"task_small\", 1, model_params=\"10M\")  # Should get fewer\n\n        # 4 workers, 2 tasks. 'param_aware'\n        # task_big (300), task_small (10)\n        # Initial: task_big:1, task_small:1 (2 workers used)\n        # Remaining 2:\n        # 1. task_big (+1) -> 2\n        # 2. task_big (+1) -> 3 (if cycling favors largest first for all remainders)\n        # Actually, it's 1. task_big (+1) -> 2, 2. task_small (+1) -> 2 if cycling through all tasks\n        # Let's re-check _apportion_workers_to_tasks for param_aware with 2 tasks, 4 workers:\n        # Initial: big:1, small:1. Remaining: 2.\n        # Sorted by size: [big, small]\n        # 1. big gets +1 -> big:2, small:1\n        # 2. small gets +1 -> big:2, small:2  <-- This is how current param_aware works (cycles)\n        # To force one task to get 3, let's use 3 workers for 1 task.\n        scheduler_1n3g = TaskScheduler(num_physical_nodes=1, gpus_per_node=3)\n        assignments_3w = scheduler_1n3g.schedule_and_assign_tasks(\n            [task_big],  # Only one task\n            apportion_strategy=\"even\",  # Will get all 3 workers\n            consider_node_cohesion=True,\n        )\n        assigned_workers_for_task_big = []\n        for r, g in assignments_3w.items():\n            if g and \"task_big\" in g.graph_id:\n                assigned_workers_for_task_big.append(r)\n        self.assertEqual(len(assigned_workers_for_task_big), 3)\n        self.assertTrue(all((rank // scheduler_1n3g.gpus_per_node) == 0 for rank in assigned_workers_for_task_big))\n\n    def test_placement_no_node_load_no_rank_preference(self):\n        \"\"\"\n        Test placement when node load and rank preference are disabled.\n        Cohesion should still work if enabled. 
Placement might be less deterministic for tie-breaking.\n        \"\"\"\n        scheduler = TaskScheduler(num_physical_nodes=2, gpus_per_node=2)  # 4 workers (0,1 on node 0; 2,3 on node 1)\n        task1 = create_test_graph(\"t1\", 1, model_params=\"100M\")\n        task2 = create_test_graph(\"t2\", 1, model_params=\"100M\")\n        # 2 tasks, 4 workers. 'even' -> 2 workers per task.\n\n        assignments = scheduler.schedule_and_assign_tasks(\n            [task1, task2],\n            apportion_strategy=\"even\",\n            consider_node_cohesion=True,  # Cohesion is on\n            consider_node_load=False,  # Load balancing is off\n            consider_rank_preference=False,  # Rank preference is off\n        )\n\n        # Verify all workers are assigned\n        self.assertEqual(len([g for g in assignments.values() if g is not None]), 4)\n\n        # Check that each task got 2 workers\n        worker_counts = collections.defaultdict(int)\n        task_assignments = collections.defaultdict(list)\n        for worker, graph in assignments.items():\n            if graph:\n                # Extract base name (e.g., \"t1\" from \"t1_final_1\")\n                base_graph_id = scheduler._get_original_graph_id(graph) if \"_final_\" in graph.graph_id else graph.graph_id\n                worker_counts[base_graph_id] += 1\n                task_assignments[base_graph_id].append(worker)\n\n        self.assertEqual(worker_counts.get(\"t1\"), 2)\n        self.assertEqual(worker_counts.get(\"t2\"), 2)\n\n        # With cohesion ON, the 2 workers for t1 should be on the same node.\n        # And the 2 workers for t2 should be on the same node.\n        # Node 0: GPUs 0, 1. Node 1: GPUs 2, 3.\n\n        # Example: t1_final_1 assigned to workers on node 0 (0,1)\n        #          t2_final_1 assigned to workers on node 1 (2,3)\n        # This is one possible cohesive assignment.\n\n        # We need to find which task is which after potential renaming by discover_and_split\n        scheduled_task_ids = list(task_assignments.keys())\n        if not (len(scheduled_task_ids) == 2 and all(\"_final_\" in tid for tid in scheduled_task_ids)):\n            # If splitting didn't happen as expected, this test might be flawed.\n            # However, for simple graphs, they are returned as is and then renamed.\n            pass\n\n        for task_id_key in task_assignments:  # task_id_key will be like \"t1_final_1\"\n            assigned_ranks = task_assignments[task_id_key]\n            self.assertEqual(len(assigned_ranks), 2, f\"Task {task_id_key} should have 2 workers\")\n            # Check cohesion: both workers for this task should be on the same physical node\n            physical_node_for_task = assigned_ranks[0] // scheduler.gpus_per_node\n            self.assertTrue(all((r // scheduler.gpus_per_node) == physical_node_for_task for r in assigned_ranks), f\"Workers for task {task_id_key} are not on the same physical node: {assigned_ranks}\")\n\n        # Check that the two tasks are on different physical nodes if possible (due to even distribution of tasks)\n        # This is a secondary effect of load balancing if it were on, but here, it's about filling nodes.\n        node_indices_used_by_tasks = set()\n        for task_id_key in task_assignments:\n            node_indices_used_by_tasks.add(task_assignments[task_id_key][0] // scheduler.gpus_per_node)\n\n        # If there are enough nodes for each task to be on a separate node, they should be.\n        if scheduler.num_physical_nodes >= 
len(task_assignments):\n            self.assertEqual(len(node_indices_used_by_tasks), len(task_assignments), \"Tasks should be on different physical nodes if possible.\")\n\n\nif __name__ == \"__main__\":\n    unittest.main(argv=[\"first-arg-is-ignored\"], exit=False, verbosity=2)\n"
  }
]